In [1]:
# Mount Google Drive so the checkpoint folder under MyDrive becomes readable.
from google.colab import drive as gdrive

gdrive.mount('/content/drive')

Mounted at /content/drive


In [2]:
# Sanity check: list the model folder on the mounted Drive to confirm the
# checkpoint file is present before the load cell below tries to read it.
!ls /content/drive/MyDrive/MyModel/

ckpt_epoch40.pt


In [3]:
import torch

# Prefer the GPU when one is visible; the checkpoint is later mapped onto this device.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
checkpoint_path = '/content/drive/MyDrive/MyModel/ckpt_epoch40.pt'

In [4]:
from transformers import AutoTokenizer

# Must be the same tokenizer the checkpoint was trained with: its vocab size
# sizes the embedding and output layers in the hyper-parameter cell.
TOKENIZER_NAME = "albert-base-v2"
tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_NAME)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json:   0%|          | 0.00/25.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/684 [00:00<?, ?B/s]

spiece.model:   0%|          | 0.00/760k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.31M [00:00<?, ?B/s]

In [5]:
import torch
import torch.nn as nn

# Architecture hyper-parameters. They must match the values the checkpoint was
# trained with, otherwise loading its weights fails on shape mismatches.
vocab_size = tokenizer.vocab_size   # rows of the token embedding / columns of the output layer
d_model = 256                       # model (embedding) width
n_heads = 4                         # attention heads per attention layer
n = 4                               # number of encoder blocks, and of decoder blocks
head_size = d_model // n_heads      # per-head width; head outputs are concatenated back to d_model
max_len = 512                       # positional-embedding slots = longest accepted sequence
dropout_rate = 0.2
class Embedding(nn.Module):
    """Sum of a token embedding and a learned absolute positional embedding.

    Args:
        vocab_size: number of token ids the token table can look up.
        d_model: embedding width.
        max_len: number of positional slots, i.e. the longest sequence
            ``forward`` accepts.
    """

    def __init__(self, vocab_size, d_model, max_len=512):
        super().__init__()
        # Plain attribute (not a parameter/buffer), so state_dict keys are unchanged.
        self.max_len = max_len
        self.token_embedding = nn.Embedding(vocab_size, d_model)
        self.positional_embedding = nn.Embedding(max_len, d_model)

    def forward(self, x):
        """Embed a (batch, seq_len) tensor of token ids into (batch, seq_len, d_model).

        Raises:
            ValueError: if ``seq_len`` exceeds ``max_len``. Without this check the
                positional lookup fails with a much less descriptive index error.
        """
        _, seq_len = x.shape
        if seq_len > self.max_len:
            raise ValueError(
                f"sequence length {seq_len} exceeds the positional embedding size "
                f"max_len={self.max_len}; truncate the input first"
            )
        positions = torch.arange(seq_len, device=x.device)
        # The (seq_len, d_model) positional term broadcasts over the batch dimension.
        return self.token_embedding(x) + self.positional_embedding(positions)

In [6]:
class SingleHeadSelfAttention(nn.Module):
    """One scaled dot-product self-attention head (no causal mask).

    ``forward(x, mask)`` projects ``x`` to queries/keys/values of width
    ``head_size``. An optional ``mask`` is unsqueezed on dim 1 and broadcast over
    the query axis; key positions where it is 0 get ``-inf`` before the softmax.
    NOTE(review): this assumes a (batch, seq_len) padding mask.
    """

    def __init__(self, d_model, head_size):
        super().__init__()
        self.head_size = head_size
        self.d_model = d_model
        self.query = nn.Linear(d_model, head_size, bias=False)
        self.key = nn.Linear(d_model, head_size, bias=False)
        self.value = nn.Linear(d_model, head_size, bias=False)

    def forward(self, x, mask=None):
        queries, keys, values = self.query(x), self.key(x), self.value(x)
        scale = self.head_size ** 0.5
        scores = queries @ keys.transpose(-2, -1) / scale
        if mask is not None:
            blocked = mask.unsqueeze(1) == 0
            scores = scores.masked_fill(blocked, float('-inf'))
        weights = scores.softmax(dim=-1)
        return weights @ values

In [7]:
class MultiHeadAttention(nn.Module):
    """``n_heads`` independent self-attention heads whose outputs are concatenated
    on the last dimension and mixed by a ``d_model -> d_model`` linear projection.

    The projection expects the concatenation to be ``d_model`` wide, i.e.
    ``n_heads * head_size == d_model``.
    """

    def __init__(self, n_heads, d_model, head_size):
        super().__init__()
        self.d_model = d_model
        self.head_size = head_size
        self.n_heads = n_heads
        heads = [SingleHeadSelfAttention(d_model, head_size) for _ in range(n_heads)]
        self.multiheads = nn.ModuleList(heads)
        self.projection_layer = nn.Linear(d_model, d_model)

    def forward(self, x, mask=None):
        per_head = [head(x, mask=mask) for head in self.multiheads]
        return self.projection_layer(torch.cat(per_head, dim=-1))

In [8]:
class MaskedSingleHeadAttention(nn.Module):
    """One causal (lower-triangular masked) self-attention head.

    Position ``t`` can only attend to positions ``<= t``. The triangular mask is
    built once for ``max_len`` and sliced to the current sequence length.
    """

    def __init__(self, d_model, head_size, max_len):
        super().__init__()
        self.head_size = head_size
        self.query = nn.Linear(d_model, head_size, bias=False)
        self.key = nn.Linear(d_model, head_size, bias=False)
        self.value = nn.Linear(d_model, head_size, bias=False)
        # Registered as a (persistent) buffer: it moves with .to(device) and is part
        # of state_dict, so existing checkpoints contain a 'tril' entry.
        self.register_buffer('tril', torch.tril(torch.ones(max_len, max_len)))

    def forward(self, x):
        _, seq_len, _ = x.shape
        queries, keys, values = self.query(x), self.key(x), self.value(x)
        scores = queries @ keys.transpose(-2, -1) / (self.head_size ** 0.5)
        future = self.tril[:seq_len, :seq_len] == 0
        scores = scores.masked_fill(future, float('-inf'))
        return torch.softmax(scores, dim=-1) @ values

In [9]:
class MaskedMultiHeadAttention(nn.Module):
    """``n_heads`` causal self-attention heads, concatenated and linearly projected
    back to ``d_model`` (expects ``n_heads * head_size == d_model``)."""

    def __init__(self, n_heads, d_model, head_size, max_len):
        super().__init__()
        self.d_model = d_model
        self.head_size = head_size
        self.n_heads = n_heads
        heads = [MaskedSingleHeadAttention(d_model, head_size, max_len) for _ in range(n_heads)]
        self.multiheads = nn.ModuleList(heads)
        self.projection_layer = nn.Linear(d_model, d_model)

    def forward(self, x):
        per_head = [head(x) for head in self.multiheads]
        return self.projection_layer(torch.cat(per_head, dim=-1))

In [10]:
import torch.nn.functional as F
class CrossAttention(nn.Module):
    """One attention head whose queries come from the decoder stream
    (``masked_attention``) and whose keys/values come from ``encoder_output``.

    An optional ``mask`` is unsqueezed on dim 1 and broadcast over the decoder
    positions; encoder positions where it is 0 are excluded via ``-inf`` scores.
    NOTE(review): this assumes a (batch, encoder_len) padding mask.
    """

    def __init__(self, d_model, head_size):
        super().__init__()
        self.head_size = head_size
        self.query = nn.Linear(d_model, head_size, bias=False)
        self.key = nn.Linear(d_model, head_size, bias=False)
        self.value = nn.Linear(d_model, head_size, bias=False)

    def forward(self, encoder_output, masked_attention, mask=None):
        scale = self.head_size ** 0.5
        queries = self.query(masked_attention)
        keys = self.key(encoder_output)
        values = self.value(encoder_output)
        scores = queries @ keys.transpose(-2, -1) / scale
        if mask is not None:
            padding = mask.unsqueeze(1) == 0
            scores = scores.masked_fill(padding, float('-inf'))
        return F.softmax(scores, dim=-1) @ values

In [11]:
class MultiHeadCrossAttention(nn.Module):
    """``n_heads`` cross-attention heads (decoder queries over encoder keys/values),
    concatenated and projected back to ``d_model``."""

    def __init__(self, n_heads, d_model, head_size):
        super().__init__()
        self.d_model = d_model
        self.head_size = head_size
        self.n_heads = n_heads
        heads = [CrossAttention(d_model, head_size) for _ in range(n_heads)]
        self.multiheads = nn.ModuleList(heads)
        self.projection_layer = nn.Linear(d_model, d_model)

    def forward(self, encoder_output, masked_attention, mask=None):
        per_head = [head(encoder_output, masked_attention, mask=mask) for head in self.multiheads]
        return self.projection_layer(torch.cat(per_head, dim=-1))

In [12]:
class PositionWiseFeedForward(nn.Module):
    """Two-layer MLP applied independently at every position:
    ``d_model -> 4 * d_model -> ReLU -> d_model``."""

    def __init__(self, d_model):
        super().__init__()
        hidden = 4 * d_model
        # Sequential indices 0/1/2 are part of the state_dict keys; keep this order.
        layers = [nn.Linear(d_model, hidden), nn.ReLU(), nn.Linear(hidden, d_model)]
        self.feed_forward = nn.Sequential(*layers)

    def forward(self, x):
        return self.feed_forward(x)

In [13]:
class EncoderBlock(nn.Module):
    """One post-norm encoder layer.

    Self-attention, then a position-wise feed-forward network; each sub-layer
    output goes through dropout, a residual add and LayerNorm.
    """

    def __init__(self, n_heads, d_model, head_size, dropout_rate):
        super().__init__()
        self.d_model = d_model
        self.head_size = head_size
        self.n_heads = n_heads
        self.dropout_rate = dropout_rate
        self.multi_head_self_att = MultiHeadAttention(n_heads, d_model, head_size)
        self.ffd = PositionWiseFeedForward(d_model)
        self.ln1 = nn.LayerNorm(d_model)
        self.ln2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, x, mask=None):
        attended = self.multi_head_self_att(x, mask=mask)
        hidden = self.ln1(x + self.dropout(attended))
        transformed = self.ffd(hidden)
        return self.ln2(hidden + self.dropout(transformed))

In [14]:
class Encoder(nn.Module):
    """Stack of ``n`` EncoderBlocks applied in order; every block receives the same mask."""

    def __init__(self, n, n_heads, d_model, head_size, dropout_rate):
        super().__init__()
        self.n = n
        self.d_model = d_model
        self.head_size = head_size
        self.n_heads = n_heads
        self.dropout_rate = dropout_rate
        layers = [EncoderBlock(n_heads, d_model, head_size, dropout_rate) for _ in range(n)]
        self.blocks = nn.ModuleList(layers)

    def forward(self, x, mask=None):
        hidden = x
        for layer in self.blocks:
            hidden = layer(hidden, mask=mask)
        return hidden

In [15]:
class DecoderBlock(nn.Module):
    """One post-norm decoder layer.

    The sub-layers run in this order, each followed by dropout, a residual add and
    LayerNorm:

    1. causal multi-head self-attention over ``decoder_input``;
    2. multi-head cross-attention from that result to ``encoder_output``
       (``mask`` is forwarded here only);
    3. a position-wise feed-forward network.
    """

    def __init__(self, n_heads, d_model, head_size, max_len, dropout_rate):
        super().__init__()
        self.n_heads = n_heads
        self.d_model = d_model
        self.head_size = head_size
        self.max_len = max_len
        self.dropout_rate = dropout_rate
        # Creation order is kept so seeded initialisation and parameter order match.
        self.masked_att = MaskedMultiHeadAttention(n_heads, d_model, head_size, max_len)
        self.ffd = PositionWiseFeedForward(d_model)
        self.cross_att = MultiHeadCrossAttention(n_heads, d_model, head_size)
        self.ln1 = nn.LayerNorm(d_model)
        self.ln2 = nn.LayerNorm(d_model)
        self.ln3 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, encoder_output, decoder_input, mask=None):
        self_attended = self.masked_att(decoder_input)
        hidden = self.ln1(decoder_input + self.dropout(self_attended))
        cross_attended = self.cross_att(encoder_output, hidden, mask=mask)
        hidden = self.ln2(hidden + self.dropout(cross_attended))
        return self.ln3(hidden + self.dropout(self.ffd(hidden)))

In [16]:
class Decoder(nn.Module):
    """Stack of ``n`` DecoderBlocks; each block sees the same encoder output and mask."""

    def __init__(self, n, n_heads, d_model, head_size, max_len, dropout_rate):
        super().__init__()
        self.n = n
        self.n_heads = n_heads
        self.d_model = d_model
        self.head_size = head_size
        self.max_len = max_len
        self.dropout_rate = dropout_rate
        layers = [DecoderBlock(n_heads, d_model, head_size, max_len, dropout_rate) for _ in range(n)]
        self.blocks = nn.ModuleList(layers)

    def forward(self, encoder_output, decoder_input, mask=None):
        hidden = decoder_input
        for layer in self.blocks:
            hidden = layer(encoder_output, hidden, mask=mask)
        return hidden

In [17]:
# Special-token ids: BOS seeds decoding, EOS ends it.
bos_idx = tokenizer.bos_token_id  # "Beginning Of Sequence"
eos_idx = tokenizer.eos_token_id  # "End Of Sequence"


class Transformer(nn.Module):
    """Encoder-decoder Transformer for dialogue summarization.

    One ``Embedding`` (token + learned positional) is shared by the encoder input
    (``dialogue``) and the decoder input (``summary``). The decoder output is
    projected to per-position vocabulary logits.

    ``mask`` is the dialogue padding mask (0 = padded). The same mask is used for
    the encoder's self-attention and for the decoder's cross-attention, because
    both attend over dialogue positions.
    """

    def __init__(self, n, n_heads, d_model, head_size, max_len, dropout_rate, vocab_size):
        super().__init__()
        self.n = n
        self.n_heads = n_heads
        self.d_model = d_model
        self.head_size = head_size
        self.max_len = max_len
        self.dropout_rate = dropout_rate
        # Sub-module registration order is kept stable (state_dict / optimizer param order).
        self.embedding = Embedding(vocab_size, d_model, max_len)
        self.encoder = Encoder(n, n_heads, d_model, head_size, dropout_rate)
        self.decoder = Decoder(n, n_heads, d_model, head_size, max_len, dropout_rate)
        self.final_layer = nn.Linear(d_model, vocab_size)

    def forward(self, dialogue, summary, mask=None):
        encoded = self.encoder(self.embedding(dialogue), mask=mask)
        decoded = self.decoder(encoded, self.embedding(summary), mask=mask)
        return self.final_layer(decoded)


In [18]:
# Freshly initialised model; its weights are replaced from the checkpoint in the next cell.
model = Transformer(
    n=n,
    n_heads=n_heads,
    d_model=d_model,
    head_size=head_size,
    max_len=max_len,
    dropout_rate=dropout_rate,
    vocab_size=vocab_size,
)

In [19]:
# Load the pretrained weights. Checkpoints saved by the training loop wrap the
# weights under 'model_state_dict' (next to epoch/optimizer state); a bare
# state dict (e.g. from torch.save(model.state_dict(), ...)) is accepted too.
checkpoint = torch.load(checkpoint_path, map_location=device)
state_dict = checkpoint.get('model_state_dict', checkpoint)
model.load_state_dict(state_dict)
model.to(device)
print("✅ Model weights successfully loaded from checkpoint!")

✅ Model weights successfully loaded from checkpoint!


In [20]:
from datasets import load_dataset

# SAMSum from the Hugging Face Hub; the splits carry 'id', 'dialogue' and 'summary' columns.
DATASET_NAME = "knkarthick/samsum"
dataset = load_dataset(DATASET_NAME)
print(dataset)

README.md: 0.00B [00:00, ?B/s]

train.csv: 0.00B [00:00, ?B/s]

validation.csv: 0.00B [00:00, ?B/s]

test.csv: 0.00B [00:00, ?B/s]

Generating train split:   0%|          | 0/14732 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/818 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/819 [00:00<?, ? examples/s]

DatasetDict({
    train: Dataset({
        features: ['id', 'dialogue', 'summary'],
        num_rows: 14732
    })
    validation: Dataset({
        features: ['id', 'dialogue', 'summary'],
        num_rows: 818
    })
    test: Dataset({
        features: ['id', 'dialogue', 'summary'],
        num_rows: 819
    })
})


In [21]:
from torch.utils.data import Dataset,DataLoader
# Tokenization lengths and batch size for the data loaders. max_len repeats the
# model's positional limit from the hyper-parameter cell; keep the two in sync.
max_len = 512        # dialogue (encoder) tokens after truncation/padding
summary_len = 128    # decoder-input / label tokens after truncation/padding
batch_size = 16
class SummarizationDataset(Dataset):
    """Map-style dataset turning (dialogue, summary) records into fixed-length id tensors.

    Each item is a dict of 1-D tensors:

    - ``encoder_input_ids`` / ``encoder_attention_mask``: the dialogue,
      padded/truncated to ``max_len``.
    - ``decoder_input_ids``: ``bos_token + summary``, padded/truncated to
      ``summary_len`` (teacher-forcing input).
    - ``labels``: ``summary + eos_token``, padded/truncated to ``summary_len``
      (targets, intended to be the decoder input shifted by one token).

    NOTE(review): if the tokenizer also inserts its own special tokens (default
    ``add_special_tokens=True``), the BOS/EOS added here sit next to the
    tokenizer's own ones. Confirm the layout before changing it, because existing
    checkpoints were trained on it.
    """

    def __init__(self, data, tokenizer, max_len, summary_len):
        self.data = data
        self.tokenizer = tokenizer
        self.max_len = max_len
        self.summary_len = summary_len

    def __len__(self):
        return len(self.data)

    def _encode(self, text, length):
        """Tokenize ``text`` with padding/truncation to ``length`` ids (batch of one)."""
        return self.tokenizer(text,
                              max_length=length,
                              padding='max_length',
                              truncation=True,
                              return_tensors='pt')

    def __getitem__(self, index):
        item = self.data[index]
        dialogue = str(item['dialogue'])
        summary = str(item['summary'])
        # Use self.tokenizer, not the module-level global, so the constructor
        # argument is actually honoured.
        bos = self.tokenizer.bos_token
        eos = self.tokenizer.eos_token

        encoder_inputs = self._encode(dialogue, self.max_len)
        decoder_inputs = self._encode(bos + summary, self.summary_len)
        labels = self._encode(summary + eos, self.summary_len)
        # flatten() drops the batch-of-one dimension added by return_tensors='pt'.
        return {
            "encoder_input_ids": encoder_inputs['input_ids'].flatten(),
            "encoder_attention_mask": encoder_inputs['attention_mask'].flatten(),
            "decoder_input_ids": decoder_inputs['input_ids'].flatten(),
            "labels": labels['input_ids'].flatten()
        }

def _make_loader(split, shuffle):
    """Wrap one dataset split in a SummarizationDataset and a batching DataLoader."""
    split_dataset = SummarizationDataset(dataset[split], tokenizer,
                                         max_len=max_len, summary_len=summary_len)
    return split_dataset, DataLoader(split_dataset, batch_size=batch_size, shuffle=shuffle)


# Training batches are reshuffled every epoch; validation order stays fixed.
train_dataset, train_loader = _make_loader('train', shuffle=True)
valid_dataset, validation_loader = _make_loader('validation', shuffle=False)

In [None]:
import os
import time

import torch
import torch.nn as nn
from torch.optim import AdamW

# --- Fine-tuning configuration ---
TOTAL_EPOCHS = 5
LEARNING_RATE = 3e-5

# Optional checkpoint to resume from; None fine-tunes the weights currently in `model`.
PRETRAIN_CKPT_PATH = None

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

pad_id = tokenizer.pad_token_id
# Padded label positions are excluded from the loss via ignore_index.
criterion = nn.CrossEntropyLoss(ignore_index=pad_id, label_smoothing=0.1)
optimizer = AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=0.01)


def _run_epoch(loader, train):
    """Run one pass over `loader` and return the mean per-batch loss.

    With `train=True` the model is in train mode and the optimizer steps after
    every batch. Otherwise the model is in eval mode with gradients disabled.
    An empty loader yields 0.0 instead of a ZeroDivisionError.
    """
    model.train(train)
    total_loss = 0.0
    with torch.set_grad_enabled(train):
        for batch in loader:
            dialogue = batch['encoder_input_ids'].to(device)
            summary = batch['decoder_input_ids'].to(device)
            labels = batch['labels'].to(device)
            dialogue_mask = batch['encoder_attention_mask'].to(device)

            if train:
                optimizer.zero_grad()
            logits = model(dialogue, summary, mask=dialogue_mask)
            # Flatten (batch, seq, vocab) logits and (batch, seq) labels for token-level CE.
            loss = criterion(logits.view(-1, logits.size(-1)), labels.view(-1))
            if train:
                loss.backward()
                optimizer.step()
            total_loss += loss.item()
    return total_loss / max(len(loader), 1)


start_epoch = 0
if PRETRAIN_CKPT_PATH is not None:
    if os.path.isfile(PRETRAIN_CKPT_PATH):
        print(f"Loading pretrained checkpoint from: {PRETRAIN_CKPT_PATH}")
        checkpoint = torch.load(PRETRAIN_CKPT_PATH, map_location=device)
        # Accept both the wrapped format saved below and a bare state dict.
        if "model_state_dict" in checkpoint:
            model.load_state_dict(checkpoint['model_state_dict'])
        else:
            model.load_state_dict(checkpoint)
        if "optimizer_state_dict" in checkpoint:
            # Best effort: an incompatible optimizer state is reported, not fatal.
            try:
                optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
                print(" Optimizer state loaded (if compatible).")
            except Exception as exc:
                print(f" Optimizer state could not be loaded (shape mismatch/optimizer differs): {exc}")
        start_epoch = checkpoint.get('epoch', 0)
    else:
        # The model keeps its current (already loaded) weights; nothing is reset.
        print(f" Provided PRETRAIN_CKPT_PATH not found: {PRETRAIN_CKPT_PATH}. Starting fine-tuning from current model weights.")
else:
    print(" No pretrained checkpoint path provided. Starting fine-tuning from current model weights.")


print(f"--- Starting fine-tuning for {TOTAL_EPOCHS} epochs ---")
for epoch in range(start_epoch, TOTAL_EPOCHS):
    start_time = time.time()
    avg_train_loss = _run_epoch(train_loader, train=True)
    avg_valid_loss = _run_epoch(validation_loader, train=False)
    epoch_duration = time.time() - start_time

    print("-" * 60)
    print(f"Epoch: {epoch + 1:02}/{TOTAL_EPOCHS} | Time: {epoch_duration:.2f}s")
    print(f"\tTrain Loss: {avg_train_loss:.4f}")
    print(f"\t Val. Loss: {avg_valid_loss:.4f}")

print("-" * 60)
print("--- Fine-tuning complete ---")

# Weights from the LAST epoch (not selected by validation loss), together with
# the optimizer state so training can be resumed via PRETRAIN_CKPT_PATH.
final_ckpt = {
    "epoch": TOTAL_EPOCHS,
    "model_state_dict": model.state_dict(),
    "optimizer_state_dict": optimizer.state_dict(),
}
FINAL_SAVE_PATH = "/content/best_finetuned.pt"   # Colab content
torch.save(final_ckpt, FINAL_SAVE_PATH)
print(f"🔖 Saved final checkpoint to: {FINAL_SAVE_PATH}")


🚀 No pretrained checkpoint path provided. Starting fine-tuning from current model weights.
--- Starting fine-tuning for 5 epochs ---
------------------------------------------------------------
Epoch: 01/5 | Time: 185.00s
	Train Loss: 3.8918
	 Val. Loss: 3.5319
------------------------------------------------------------
Epoch: 02/5 | Time: 187.27s
	Train Loss: 3.5285
	 Val. Loss: 3.4692
------------------------------------------------------------
Epoch: 03/5 | Time: 187.65s
	Train Loss: 3.4086
	 Val. Loss: 3.4351
------------------------------------------------------------
Epoch: 04/5 | Time: 188.15s
	Train Loss: 3.3185
	 Val. Loss: 3.4077
------------------------------------------------------------
Epoch: 05/5 | Time: 190.10s
	Train Loss: 3.2441
	 Val. Loss: 3.4034
------------------------------------------------------------
--- Fine-tuning complete ---
🔖 Saved final checkpoint to: /content/best_finetuned.pt


In [None]:
# --- SAVE FINAL MODEL WEIGHTS ---
# Writes a bare state dict (weights only) to the same path the fine-tuning cell
# uses for its wrapped checkpoint, so that file (including the optimizer state)
# gets overwritten.
weights_path = "/content/best_finetuned.pt"
torch.save(model.state_dict(), weights_path)
print(" Saved final model weights to /content/best_finetuned.pt")

✅ Saved final model weights to /content/best_finetuned.pt


In [26]:
import torch

# 1) Re-create the architecture with the same hyper-parameters used for training,
#    so the saved tensors line up with the module's parameters.
model = Transformer(
    n=n,
    n_heads=n_heads,
    d_model=d_model,
    head_size=head_size,
    max_len=max_len,
    dropout_rate=dropout_rate,
    vocab_size=vocab_size
)

# 2) Load the saved weights. The file may hold a bare state dict (weights-only
#    save) or the wrapped {"epoch", "model_state_dict", "optimizer_state_dict"}
#    dict written by the fine-tuning loop; accept both.
checkpoint_path = "/content/best_finetuned.pt"
checkpoint = torch.load(checkpoint_path, map_location=device)
state_dict = checkpoint.get("model_state_dict", checkpoint)
model.load_state_dict(state_dict)

# 3) Move to device and switch to eval mode (disables dropout for inference).
model.to(device)
model.eval()

Transformer(
  (embedding): Embedding(
    (token_embedding): Embedding(30000, 256)
    (positional_embedding): Embedding(512, 256)
  )
  (encoder): Encoder(
    (blocks): ModuleList(
      (0-3): 4 x EncoderBlock(
        (multi_head_self_att): MultiHeadAttention(
          (multiheads): ModuleList(
            (0-3): 4 x SingleHeadSelfAttention(
              (query): Linear(in_features=256, out_features=64, bias=False)
              (key): Linear(in_features=256, out_features=64, bias=False)
              (value): Linear(in_features=256, out_features=64, bias=False)
            )
          )
          (projection_layer): Linear(in_features=256, out_features=256, bias=True)
        )
        (ffd): PositionWiseFeedForward(
          (feed_forward): Sequential(
            (0): Linear(in_features=256, out_features=1024, bias=True)
            (1): ReLU()
            (2): Linear(in_features=1024, out_features=256, bias=True)
          )
        )
        (ln1): LayerNorm((256,), eps=1e

In [74]:
SAMPLE_INDEX = 6  # test-split example used below to eyeball generation quality
test_split = dataset['test']
test_dialogue = test_split['dialogue'][SAMPLE_INDEX]
test_summary = test_split['summary'][SAMPLE_INDEX]
for value in (test_dialogue, test_summary):
    print(type(value))
print("#### the dialogue")
print(test_dialogue)
print("#### the summary")
print(test_summary)

<class 'str'>
<class 'str'>
#### the dialogue
Max: Know any good sites to buy clothes from?
Payton: Sure :) <file_other> <file_other> <file_other> <file_other> <file_other> <file_other> <file_other>
Max: That's a lot of them!
Payton: Yeah, but they have different things so I usually buy things from 2 or 3 of them.
Max: I'll check them out. Thanks. 
Payton: No problem :)
Max: How about u?
Payton: What about me?
Max: Do u like shopping?
Payton: Yes and no.
Max: How come?
Payton: I like browsing, trying on, looking in the mirror and seeing how I look, but not always buying.
Max: Y not?
Payton: Isn't it obvious? ;)
Max: Sry ;)
Payton: If I bought everything I liked, I'd have nothing left to live on ;)
Max: Same here, but probably different category ;)
Payton: Lol
Max: So what do u usually buy?
Payton: Well, I have 2 things I must struggle to resist!
Max: Which are?
Payton: Clothes, ofc ;)
Max: Right. And the second one?
Payton: Books. I absolutely love reading!
Max: Gr8! What books do u re

In [None]:
@torch.no_grad()
def generate(model: Transformer, dialogue: str, max_len: int = 200, do_sample: bool = True) -> str:
    """Autoregressively generate a summary for a single dialogue.

    Args:
        model: the trained encoder-decoder ``Transformer``.
        dialogue: raw dialogue text. It is tokenized with the module-level
            ``tokenizer`` and truncated to ``model.max_len`` tokens (the
            positional-embedding limit), as the training dataset truncates
            encoder inputs.
        max_len: maximum number of tokens to generate after BOS. It is capped so the
            decoder sequence never exceeds ``model.max_len`` positions.
        do_sample: if True (default), sample each token from the softmax
            distribution; if False, take the arg-max token (deterministic).

    Returns:
        The decoded summary with special tokens removed.
    """
    was_training = model.training
    model.eval()  # disable dropout for inference; the previous mode is restored below
    try:
        tok_out = tokenizer(
            dialogue,
            max_length=model.max_len,
            truncation=True,
            return_tensors="pt",
        ).to(device)
        encoder_ids = tok_out["input_ids"]
        encoder_mask = tok_out["attention_mask"]
        output_ids = torch.tensor([[tokenizer.bos_token_id]], device=device)

        # The positional embedding has model.max_len slots and output_ids already
        # holds BOS, so at most model.max_len - 1 tokens can be appended.
        n_steps = min(max_len, model.max_len - 1)
        for _ in range(n_steps):
            logits = model(encoder_ids, output_ids, mask=encoder_mask)
            next_logits = logits[:, -1, :]
            if do_sample:
                probs = nn.functional.softmax(next_logits, dim=-1)
                next_id = torch.multinomial(probs, num_samples=1)
            else:
                next_id = next_logits.argmax(dim=-1, keepdim=True)
            output_ids = torch.cat([output_ids, next_id], dim=-1)
            if next_id.item() == eos_idx:
                break
    finally:
        model.train(was_training)

    # output_ids[0] is always 1-D, even if no token was generated.
    return tokenizer.decode(output_ids[0], skip_special_tokens=True)


summary = generate(model, test_dialogue)
print(summary)


payton will buy clothes from payton's place and buy clothes from goodbye to him.
