In [16]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from collections import Counter
import nltk
from nltk.tokenize import word_tokenize
import re
import string
import math
from torch.optim.lr_scheduler import LambdaLR
!pip install nltk
from collections import Counter
import nltk
nltk.download('punkt')
nltk.download('punkt_tab')
nltk.download('stopwords') # Added to download stopwords
import re



[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt.zip.
[nltk_data] Downloading package punkt_tab to /root/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt_tab.zip.
[nltk_data] Downloading package stopwords to /root/nltk_data...
[nltk_data]   Unzipping corpora/stopwords.zip.


In [2]:

class TextPreprocessor:
    def __init__(self, text):
        self.text = text
        self.slang_dict = {
            "AFAIK": "As Far As I Know",
            "AFK": "Away From Keyboard",
            "ASAP": "As Soon As Possible",
            "ATK": "At The Keyboard",
            "ATM": "At The Moment",
            "A3": "Anytime, Anywhere, Anyplace",
            "BAK": "Back At Keyboard",
            "BBL": "Be Back Later",
            "BBS": "Be Back Soon",
            "BFN": "Bye For Now",
            "B4N": "Bye For Now",
            "BRB": "Be Right Back",
            "BRT": "Be Right There",
            "BTW": "By The Way",
            "B4": "Before",
            "CU": "See You",
            "CUL8R": "See You Later",
            "CYA": "See You",
            "FAQ": "Frequently Asked Questions",
            "FC": "Fingers Crossed",
            "FWIW": "For What It's Worth",
            "FYI": "For Your Information",
            "GAL": "Get A Life",
            "GG": "Good Game",
            "GN": "Good Night",
            "GMTA": "Great Minds Think Alike",
            "GR8": "Great!",
            "G9": "Genius",
            "IC": "I See",
            "ICQ": "I Seek you (also a chat program)",
            "ILU": "I Love You",
            "IMHO": "In My Honest/Humble Opinion",
            "IMO": "In My Opinion",
            "IOW": "In Other Words",
            "IRL": "In Real Life",
            "KISS": "Keep It Simple, Stupid",
            "LDR": "Long Distance Relationship",
            "LMAO": "Laugh My A** Off",
            "LOL": "Laughing Out Loud",
            "LTNS": "Long Time No See",
            "L8R": "Later",
            "MTE": "My Thoughts Exactly",
            "M8": "Mate",
            "NRN": "No Reply Necessary",
            "OIC": "Oh I See",
            "PITA": "Pain In The A**",
            "PRT": "Party",
            "PRW": "Parents Are Watching",
            "QPSA?": "Que Pasa?",
            "ROFL": "Rolling On The Floor Laughing",
            "ROFLOL": "Rolling On The Floor Laughing Out Loud",
            "ROTFLMAO": "Rolling On The Floor Laughing My A** Off",
            "SK8": "Skate",
            "STATS": "Your sex and age",
            "ASL": "Age, Sex, Location",
            "THX": "Thank You",
            "TTFN": "Ta-Ta For Now!",
            "TTYL": "Talk To You Later",
            "U": "You",
            "U2": "You Too",
            "U4E": "Yours For Ever",
            "WB": "Welcome Back",
            "WTF": "What The F...",
            "WTG": "Way To Go!",
            "WUF": "Where Are You From?",
            "W8": "Wait...",
            "7K": "Sick:-D Laugher",
            "TFW": "That Feeling When",
            "MFW": "My Face When",
            "MRW": "My Reaction When",
            "IFYP": "I Feel Your Pain",
            "TNTL": "Trying Not To Laugh",
            "JK": "Just Kidding",
            "IDC": "I Don't Care",
            "ILY": "I Love You",
            "IMU": "I Miss You",
            "ADIH": "Another Day In Hell",
            "ZZZ": "Sleeping, Bored, Tired",
            "WYWH": "Wish You Were Here",
            "TIME": "Tears In My Eyes",
            "BAE": "Before Anyone Else",
            "FIMH": "Forever In My Heart",
            "BSAAW": "Big Smile And A Wink",
            "BWL": "Bursting With Laughter",
            "BFF": "Best Friends Forever",
            "CSL": "Can't Stop Laughing"
        }

    def to_lower(self):
        self.text = self.text.lower()

    def remove_urls(self):
        url_pattern = r'https?://\S+|www\.\S+'
        self.text = re.sub(url_pattern, '', self.text)

    def remove_punc(self):
        exclude = string.punctuation
        self.text = self.text.translate(str.maketrans('', '', exclude))

    def remove_selective_punc(self):
        """Remove all punctuation except brackets, semicolons, angle brackets, and equals sign using regex"""
        self.text = re.sub(r'[^\w\s()\[\]{}]', '', self.text)
        return self.text

    def remove_numbers(self):
        self.text = re.sub(r'\d+', '', self.text)

    def remove_abbr(self):
        words = self.text.split(" ")
        processed_words = []
        for word in words:
            if word.upper() in self.slang_dict:
                processed_words.append(self.slang_dict[word.upper()])
            else:
                processed_words.append(word)
        self.text = " ".join(processed_words)

    def remove_stopwords(self):
        stopwords = nltk.corpus.stopwords.words('english')
        words = self.text.split()
        filtered_words = [word for word in words if word.lower() not in stopwords]
        self.text = ' '.join(filtered_words)

    def clean_whitespace(self):
        # Remove extra whitespace and clean up
        self.text = ' '.join(self.text.split())

    def gettext(self):
        return self.text

    def process_text(self):
        # self.to_lower()
        self.remove_urls()
        self.remove_selective_punc()
        # self.remove_numbers()
        self.remove_abbr()
        # self.remove_stopwords()
        self.clean_whitespace()
        return self.text

In [3]:
class Vocabulary:
    def __init__(self, text, min_freq=3, tokenizer=word_tokenize):
        self.tokenizer = tokenizer
        self.min_freq = min_freq
        self.word2index = {'<PAD>': 0, '<UNK>': 1, '<SOS>': 2, '<EOS>': 3}  # Add special tokens
        self.index2word = {0: '<PAD>', 1: '<UNK>', 2: '<SOS>', 3: '<EOS>'}
        self.tokens = []
        self.numerical_text = []
        self.text = text
        self.word_freq = {}
        self._vocab_built = False

    def build_vocabulary(self):
        if isinstance(self.text, str):
            self.tokens = [self.tokenizer(self.text)]
        else:
            self.tokens = [self.tokenizer(t) for t in self.text]


        for token_list in self.tokens:
            for word in token_list:
                self.word_freq[word] = self.word_freq.get(word, 0) + 1


        index = 4
        for word, freq in self.word_freq.items():
            if freq >= self.min_freq:
                self.word2index[word] = index
                self.index2word[index] = word
                index += 1

        self._vocab_built = True
        print(f"Vocabulary built with {len(self.word2index)} tokens")
        print(f"Most frequent words: {sorted(self.word_freq.items(), key=lambda x: x[1], reverse=True)[:10]}")
        return self.word2index

    def vectorize(self):
        if not self._vocab_built:
            self.build_vocabulary()

        self.numerical_text = [
            [self.word2index.get(token, 1) for token in tokens]  # Use 1 for <UNK>
            for tokens in self.tokens
        ]
        return self.numerical_text

    def devectorize(self, num_text):
        if not self._vocab_built:
            raise ValueError("Vocabulary not built yet. Call build_vocabulary() first.")
        words = list([self.index2word.get(idx, '<UNK>') for idx in num_text])
        return " ".join(words)

    def zero_pad(self, max_len=None):
        if not self.numerical_text:
            self.vectorize()

        if max_len is None:
            max_len = max(len(seq) for seq in self.numerical_text)

        padded = [
            seq + [0] * (max_len - len(seq)) if len(seq) < max_len
            else seq[:max_len]
            for seq in self.numerical_text
        ]
        return padded

    def get_vocab_size(self):
        return len(self.word2index)



In [4]:
def get_batch(split: str, train_data, val_data, batch_size=32, context_size=16):
    """Improved batch generation with better error handling and bounds checking"""
    data = train_data if split == 'train' else val_data

    if len(data) <= context_size:
        print(f"Warning: Data length ({len(data)}) <= context_size ({context_size})")
        return None, None


    max_start_index = len(data) - context_size - 1
    if max_start_index <= 0:
        print(f"Error: Not enough data for context_size {context_size}")
        return None, None

    starts = torch.randint(0, max_start_index, (batch_size,))

    x = torch.stack([data[i:i+context_size] for i in starts])
    y = torch.stack([data[i+1:i+context_size+1] for i in starts])


    vocab_size = train_data.max().item() + 1 if len(train_data) > 0 else 1000
    x = torch.clamp(x, 0, vocab_size - 1)
    y = torch.clamp(y, 0, vocab_size - 1)


    if (x == 0).all():
        print("Warning: All-zero input batch detected!")
    if (y == 0).all():
        print("Warning: All-zero target batch detected!")

    return x, y

In [5]:
class TransformerBlock(nn.Module):
    def __init__(self, embed_size, num_heads, ff_size=256*2, dropout=0.1, device='cpu'):
        super(TransformerBlock, self).__init__()
        self.attention = nn.MultiheadAttention(embed_dim=embed_size, num_heads=num_heads, dropout=dropout, batch_first=True)
        self.ff = nn.Sequential(
            nn.Linear(embed_size, ff_size),
            nn.GELU(),
            nn.Linear(ff_size, embed_size),
            nn.Dropout(dropout)
        )
        self.norm1 = nn.LayerNorm(embed_size)
        self.norm2 = nn.LayerNorm(embed_size)
        self.dropout = nn.Dropout(dropout)
        self.device = device

    def forward(self, x, attn_mask=None):
        x = x.to(self.device)

        # Pre-norm architecture
        residual = x
        x = self.norm1(x)

        # Self-attention with batch_first=True
        attention, _ = self.attention(x, x, x, attn_mask=attn_mask, need_weights=False)
        x = self.dropout(attention) + residual

        # Feed forward
        residual = x
        x = self.norm2(x)
        x = self.ff(x)
        x = self.dropout(x) + residual

        return x


In [6]:
class DisasterGPT(nn.Module):
    def __init__(self, vocab_size, embed_size=512, num_heads=8, num_layers=3, context_size=128, dropout=0.1, device='cpu'):
        super().__init__()
        self.vocab_size = vocab_size
        self.embed_size = embed_size
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.context_size = context_size
        self.device = device

        self.dropout = nn.Dropout(dropout)

        # CRITICAL FIX: Ensure embedding handles vocab_size correctly
        self.embedding = nn.Embedding(vocab_size, embed_size, padding_idx=0)
        self.positional_encoding = nn.Parameter(torch.zeros(1, context_size, embed_size), requires_grad=True)

        ff_size = 4 * embed_size
        self.transformer_blocks = nn.ModuleList([
            TransformerBlock(embed_size, num_heads, ff_size, dropout, device) for _ in range(num_layers)
        ])
        self.fc_out = nn.Linear(embed_size, vocab_size)
        self.ln_final = nn.LayerNorm(embed_size)

        self.init_weights()

    def init_weights(self):
        """Proper weight initialization to prevent NaN"""
        for module in self.modules():
            if isinstance(module, nn.Linear):
                nn.init.normal_(module.weight, mean=0.0, std=0.02)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
            elif isinstance(module, nn.Embedding):
                nn.init.normal_(module.weight, mean=0.0, std=0.02)
                # Zero out padding token embedding
                if hasattr(self, 'embedding') and module == self.embedding:
                    with torch.no_grad():
                        module.weight[0].fill_(0)
            elif isinstance(module, nn.LayerNorm):
                nn.init.ones_(module.weight)
                nn.init.zeros_(module.bias)

    def generate_causal_mask(self, size):
        """Generate causal mask for self-attention"""
        mask = torch.triu(torch.ones(size, size), diagonal=1).bool()
        return mask.to(self.device)

    def forward(self, x):

        x = torch.clamp(x, 0, self.vocab_size - 1)

        B, T = x.size()


        token_emb = self.embedding(x)  # Shape: (B, T, embed_size)
        token_emb = self.dropout(token_emb)


        if T > self.positional_encoding.size(1):
            # Handle sequences longer than context_size
            pos_emb = self.positional_encoding[:, :1, :].repeat(1, T, 1)
        else:
            pos_emb = self.positional_encoding[:, :T, :]

        pos_emb = pos_emb.to(self.device)


        x = token_emb + pos_emb
        x = x.to(self.device)

        # Generate causal mask
        causal_mask = self.generate_causal_mask(T)

        # Transformer blocks
        for block in self.transformer_blocks:
            x = block(x, attn_mask=causal_mask)

        # Final layer norm and output projection
        x = self.ln_final(x)
        x = self.dropout(x)
        x = self.fc_out(x)

        return x

In [7]:
def get_cosine_schedule_with_warmup(optimizer, warmup_steps, total_steps):
    def lr_lambda(current_step):
        if current_step < warmup_steps:
            return float(current_step) / float(max(1, warmup_steps))
        progress = float(current_step - warmup_steps) / float(max(1, total_steps - warmup_steps))
        return max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))

    return LambdaLR(optimizer, lr_lambda)

In [8]:
def setup_training(model, lr=1e-4, device='cpu', warmup=500, total_steps=10000):
    """Setup training with improved hyperparameters"""
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=lr,
        betas=(0.9, 0.95),
        weight_decay=0.1,
        eps=1e-8
    )

    # Use label smoothing to prevent overconfidence
    loss_fn = nn.CrossEntropyLoss(label_smoothing=0.1, ignore_index=0)

    scheduler = get_cosine_schedule_with_warmup(optimizer, warmup, total_steps=total_steps)
    return optimizer, loss_fn, scheduler

In [9]:
def train(model, optimizer, loss_fn, scheduler, train_data, val_data, device='cpu', epochs=10,
          steps_per_epoch=500, patience=15, min_delta=1e-4, context_size=128, batch_size=32):
    """
    Train the model with early stopping and NaN detection
    """
    print(f"Moving model to device: {device}")
    model = model.to(device)
    model.train()

    best_val_loss = float('inf')
    patience_counter = 0
    training_history = {'train_loss': [], 'train_accuracy': [], 'val_loss': [], 'val_accuracy': []}

    for epoch in range(epochs):
        total_loss = 0
        train_correct = 0
        train_total = 0
        valid_steps = 0

        for step in range(steps_per_epoch):
            try:
                xb, yb = get_batch('train', train_data=train_data, val_data=val_data,
                                 batch_size=batch_size, context_size=context_size)

                if xb is None or yb is None:
                    continue

                xb, yb = xb.to(device), yb.to(device)

                optimizer.zero_grad()
                logits = model(xb)
                loss = loss_fn(logits.view(-1, logits.size(-1)), yb.view(-1))

                # Check for NaN loss
                if torch.isnan(loss) or torch.isinf(loss):
                    print(f"NaN/Inf loss detected at epoch {epoch+1}, step {step+1}")
                    continue

                loss.backward()

                # Gradient clipping
                grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

                # Check for NaN gradients
                nan_grads = False
                for param in model.parameters():
                    if param.grad is not None and (torch.isnan(param.grad).any() or torch.isinf(param.grad).any()):
                        nan_grads = True
                        break

                if nan_grads:
                    print(f"NaN gradients detected at epoch {epoch+1}, step {step+1}")
                    optimizer.zero_grad()
                    continue

                optimizer.step()
                scheduler.step()

                # Calculate training accuracy
                with torch.no_grad():
                    train_preds = torch.argmax(logits, dim=-1)
                    # Only count non-padding tokens
                    mask = (yb != 0)
                    train_correct += ((train_preds == yb) & mask).sum().item()
                    train_total += mask.sum().item()

                total_loss += loss.item()
                valid_steps += 1

            except Exception as e:
                print(f"Error in training step {step}: {e}")
                continue

        if valid_steps == 0:
            print(f"No valid training steps in epoch {epoch+1}")
            break

        avg_loss = total_loss / valid_steps
        train_accuracy = train_correct / train_total if train_total > 0 else 0.0

        # Validation
        model.eval()
        val_total_loss = 0
        val_correct = 0
        val_total = 0
        valid_val_batches = 0

        with torch.no_grad():
            try:
                val_steps = min(50, len(val_data) // (batch_size * context_size))
                for val_step in range(max(1, val_steps)):
                    xv, yv = get_batch('val', train_data, val_data,
                                     batch_size=batch_size, context_size=context_size)

                    if xv is None or yv is None:
                        continue

                    xv, yv = xv.to(device), yv.to(device)
                    val_logits = model(xv)
                    val_loss = loss_fn(val_logits.view(-1, val_logits.size(-1)), yv.view(-1))

                    if torch.isnan(val_loss) or torch.isinf(val_loss):
                        continue

                    val_total_loss += val_loss.item()
                    valid_val_batches += 1

                    preds = torch.argmax(val_logits, dim=-1)
                    mask = (yv != 0)
                    val_correct += ((preds == yv) & mask).sum().item()
                    val_total += mask.sum().item()

                if valid_val_batches > 0:
                    avg_val_loss = val_total_loss / valid_val_batches
                    val_accuracy = val_correct / val_total if val_total > 0 else 0.0
                else:
                    avg_val_loss = float('inf')
                    val_accuracy = 0.0

            except Exception as e:
                print(f"Error in validation: {e}")
                avg_val_loss = float('inf')
                val_accuracy = 0.0

        model.train()

        # Store training history
        training_history['train_loss'].append(avg_loss)
        training_history['train_accuracy'].append(train_accuracy)
        training_history['val_loss'].append(avg_val_loss)
        training_history['val_accuracy'].append(val_accuracy)

        print(f"Epoch {epoch+1}/{epochs} | Train Loss: {avg_loss:.4f} | Train Acc: {train_accuracy:.4f} | Val Loss: {avg_val_loss:.4f} | Val Acc: {val_accuracy:.4f}")

        # Early stopping
        if not (torch.isnan(torch.tensor(avg_val_loss)) or torch.isinf(torch.tensor(avg_val_loss))):
            if avg_val_loss < best_val_loss - min_delta:
                best_val_loss = avg_val_loss
                patience_counter = 0
                # Save best model
                try:
                    torch.save(model.state_dict(), 'best_model.pth')
                except Exception as e:
                    print(f"Error saving model: {e}")
            else:
                patience_counter += 1
        else:
            patience_counter += 1

        if patience_counter >= patience:
            print(f"Early stopping at epoch {epoch+1}")
            break

    return training_history



In [37]:
# def generate_text(model, start_text, vocab_builder, max_length=50, device='cpu'):
#     model.eval()
#     tokens = vocab_builder.tokenizer(start_text)
#     input_ids = torch.tensor([vocab_builder.word2index.get(token, 1) for token in tokens]).unsqueeze(0).to(device)  # Add batch dimension
#     generated_text = start_text
#     with torch.no_grad():
#         for _ in range(max_length):
#             input_ids = input_ids.to(device)
#             logits = model(input_ids)
#             next_token_logits = logits[:, -1, :]  # Get logits for the last token
#             next_token_id = torch.argmax(next_token_logits, dim=-1).item()  # Get the predicted token ID

#             if next_token_id == 0:  # If <PAD> token is predicted, stop generation
#                 break

#             generated_text += ' ' + vocab_builder.index2word.get(next_token_id, '<UNK>')
#             input_ids = torch.cat([input_ids, torch.tensor([[next_token_id]], device=device)], dim=1)  # Append predicted token
#     return generated_text.strip()
def generate_text(model, start_text, vocab_builder, max_length=50, device='cpu', temperature=0.8):
    """Generate text with temperature sampling"""
    model.eval()
    tokens = vocab_builder.tokenizer(start_text)
    input_ids = torch.tensor([vocab_builder.word2index.get(token, 1) for token in tokens]).unsqueeze(0).to(device)
    generated_text = start_text

    with torch.no_grad():
        for _ in range(max_length):
            input_ids = input_ids.to(device)

            # Ensure input_ids are within valid range
            input_ids = torch.clamp(input_ids, 0, vocab_builder.get_vocab_size() - 1)

            logits = model(input_ids)
            next_token_logits = logits[:, -1, :] / temperature

            # Apply softmax and sample
            probs = F.softmax(next_token_logits, dim=-1)
            next_token_id = torch.multinomial(probs, 1).item()

            if next_token_id == 0 or next_token_id == 3:  # Stop on <PAD> or <EOS>
                break

            next_word = vocab_builder.index2word.get(next_token_id, '<UNK>')
            generated_text += ' ' + next_word

            # Append predicted token
            input_ids = torch.cat([input_ids, torch.tensor([[next_token_id]], device=device)], dim=1)

            # Truncate if too long to prevent memory issues
            if input_ids.size(1) > model.context_size:
                input_ids = input_ids[:, -model.context_size:]

    return generated_text.strip()


In [11]:
with open("/content/bookcorpus.txt","r") as file:
  text_data=file.read()

In [13]:
print("Preprocessing text...")
text_preprocessor = TextPreprocessor(text_data)
processed_text = text_preprocessor.process_text()


Preprocessing text...


In [17]:
vocab_builder = Vocabulary(processed_text, min_freq=3)
vocab_builder.build_vocabulary()

Vocabulary built with 23126 tokens
Most frequent words: [('SOS', 173131), ('EOS', 173131), ('the', 133942), ('to', 64348), ('and', 63800), ('a', 52518), ('of', 48494), ('i', 46303), ('he', 34845), ('was', 34052)]


{'<PAD>': 0,
 '<UNK>': 1,
 '<SOS>': 2,
 '<EOS>': 3,
 'SOS': 4,
 'the': 5,
 'halfling': 6,
 'book': 7,
 'one': 8,
 'in': 9,
 'fall': 10,
 'of': 11,
 'igneeria': 12,
 'series': 13,
 'copyright': 14,
 '2013': 15,
 'all': 16,
 'rights': 17,
 'reserved': 18,
 'EOS': 19,
 'isbn': 20,
 'for': 21,
 'my': 22,
 'family': 23,
 'who': 24,
 'encouraged': 25,
 'me': 26,
 'to': 27,
 'never': 28,
 'stop': 29,
 'fighting': 30,
 'dreams': 31,
 'chapter': 32,
 '1': 33,
 'summer': 34,
 'vacations': 35,
 'supposed': 36,
 'be': 37,
 'fun': 38,
 'right': 39,
 'i': 40,
 'wish': 41,
 'had': 42,
 'a': 43,
 'better': 44,
 'answer': 45,
 'that': 46,
 'question': 47,
 'new': 48,
 'york': 49,
 'is': 50,
 'not': 51,
 'place': 52,
 'youd': 53,
 'expect': 54,
 'much': 55,
 'happen': 56,
 'its': 57,
 'small': 58,
 'quiet': 59,
 'town': 60,
 'kind': 61,
 'where': 62,
 'everyone': 63,
 'knows': 64,
 'your': 65,
 'name': 66,
 'parents': 67,
 'wouldnt': 68,
 'even': 69,
 'care': 70,
 'if': 71,
 'you': 72,
 'stayed': 73,
 '

In [18]:
numeric_text = vocab_builder.vectorize()
data = torch.tensor(numeric_text[0])
data = torch.clamp(data, 0, vocab_builder.get_vocab_size() - 1)
data = data.flatten()

In [19]:
train_split = 0.8
split_idx = int(len(data) * train_split)
train_data = data[:split_idx]
val_data = data[split_idx:]

In [21]:
device='cuda' if torch.cuda.is_available() else 'cpu'

In [22]:
model = DisasterGPT(
        vocab_size=vocab_builder.get_vocab_size(),
        embed_size=256,  # Reduced for stability
        num_heads=8,
        num_layers=3,
        context_size=128,
        dropout=0.1,
        device=device
    )


In [25]:
vocab_builder.get_vocab_size()

23126

In [39]:
optimizer, loss_fn, scheduler = setup_training(model, lr=1e-4, device=device)
print("Starting training...")
history = train(
        model=model,
        optimizer=optimizer,
        loss_fn=loss_fn,
        scheduler=scheduler,
        train_data=train_data,
        val_data=val_data,
        device=device,
        epochs=2,
        steps_per_epoch=500,
        context_size=256,
        batch_size=64  # Reduced for stability
    )

Starting training...
Moving model to device: cuda


KeyboardInterrupt: 

In [55]:
def get_text(prompt):

    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Instantiate the model first
    loaded_model = DisasterGPT(
        vocab_size=vocab_builder.get_vocab_size(),
        embed_size=256,
        num_heads=8,
        num_layers=3,
        context_size=128,
        dropout=0.1,
        device=device
    )

    # Load the state dictionary
    loaded_mode=torch.load("/content/finetuned_model.pth", weights_only=False)
    loaded_model.eval()
    loaded_model.to(device)
    generated_text = generate_text(loaded_model, prompt, vocab_builder, max_length=50, device=device)
    return generated_text

In [34]:
!pip install fvcore



In [57]:
from fvcore.nn import FlopCountAnalysis, parameter_count_table
import torch
import torch._dynamo
torch._dynamo.reset()

# Define input shape: (batch_size, sequence_length)
batch_size = 64
context_size = 256

# Create a dummy input
dummy_input = torch.randint(0, model.vocab_size, (batch_size, context_size)).to(model.device)

# Compute parameter count
print(parameter_count_table(model))

# Compute FLOPs
flops = FlopCountAnalysis(model, dummy_input)
print(f"FLOPs: {flops.total():,}")

transformer_blocks.0.attention.out_proj, transformer_blocks.1.attention.out_proj, transformer_blocks.2.attention.out_proj


| name                              | #elements or shape   |
|:----------------------------------|:---------------------|
| model                             | 14.3M                |
|  positional_encoding              |  (1, 128, 256)       |
|  embedding                        |  5.9M                |
|   embedding.weight                |   (23126, 256)       |
|  transformer_blocks               |  2.4M                |
|   transformer_blocks.0            |   0.8M               |
|    transformer_blocks.0.attention |    0.3M              |
|    transformer_blocks.0.ff        |    0.5M              |
|    transformer_blocks.0.norm1     |    0.5K              |
|    transformer_blocks.0.norm2     |    0.5K              |
|   transformer_blocks.1            |   0.8M               |
|    transformer_blocks.1.attention |    0.3M              |
|    transformer_blocks.1.ff        |    0.5M              |
|    transformer_blocks.1.norm1     |    0.5K              |
|    transformer_blocks.

In [40]:
with open("/content/finetuning.txt", "r") as file:
  finetuning_text=file.read()

In [42]:
def finetuning(model,text):
  finetuning_text_preprocessor=TextPreprocessor(text)
  text=finetuning_text_preprocessor.process_text()
  ft_tokens=vocab_builder.tokenizer(text)
  ft_numerical=[vocab_builder.word2index.get(token, 1) for token in ft_tokens]
  ft_data = torch.tensor(ft_numerical)
  ft_data = torch.clamp(ft_data, 0, vocab_builder.get_vocab_size() - 1)
  ft_train_split = 0.8
  ft_split_idx = int(len(ft_data) * ft_train_split)
  ft_train_data = ft_data[:ft_split_idx]
  ft_val_data = ft_data[ft_split_idx:]
  device='cuda' if torch.cuda.is_available() else 'cpu'
  optimizer, loss_fn, scheduler = setup_training(model, lr=1e-4, device=device)
  history=train(
        model=model,
        optimizer=optimizer,
        loss_fn=loss_fn,
        scheduler=scheduler,
        train_data=ft_train_data,
        val_data=ft_val_data,
        device=device,
        epochs=5,
        steps_per_epoch=100,
        context_size=256,
        batch_size=32 # Reduced for stability
    )
  return model







In [43]:
model=finetuning(model, finetuning_text)

Moving model to device: cuda
Epoch 1/5 | Train Loss: 5.6736 | Train Acc: 0.2531 | Val Loss: 6.0577 | Val Acc: 0.2604
Epoch 2/5 | Train Loss: 4.7778 | Train Acc: 0.3538 | Val Loss: 5.8509 | Val Acc: 0.2692
Epoch 3/5 | Train Loss: 4.4790 | Train Acc: 0.3715 | Val Loss: 5.7924 | Val Acc: 0.2791
Epoch 4/5 | Train Loss: 4.2487 | Train Acc: 0.3887 | Val Loss: 5.7796 | Val Acc: 0.2726
Epoch 5/5 | Train Loss: 4.0410 | Train Acc: 0.4091 | Val Loss: 6.1833 | Val Acc: 0.2539


In [54]:
torch.save(model, 'finetuned_model.pth')

In [56]:
print(get_text("My house is burning?"))

My house is burning? bee blankets costume malachi retrieved wages jfk twill tomato anticipation showering dias earpiece speaks promoting copious analyze mint posted cut passageway hotels amateur funniest gouged extending receptions quantity navigate drain cancel bloodless doganoglu dangerous command wickedly poetic mayday birthday orleans obsessed cheering openly mercias formation octave client ghastly deigned trance
