### Importing Libraries

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from transformers import BertTokenizer, BertModel


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

### Loading Dataser

In [3]:

def load_data(file_path):
    pairs = []
    with open(file_path, "r") as f:
        for line in f:
            if '\t' in line:
                input_text, target_text = line.strip().split("\t")
                pairs.append((input_text, target_text))
    return pairs


In [4]:
# Prepare Dataset and DataLoader
data_file = "dialogs.txt"
pairs = load_data(data_file)


In [5]:
# Step 1: Dataset Preparation
class DialogDataset(Dataset):
    def __init__(self, pairs, tokenizer, max_length):
        self.pairs = pairs
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        input_text, target_text = self.pairs[idx]
        input_enc = self.tokenizer(
            input_text,
            truncation=True,
            padding="max_length",
            max_length=self.max_length,
            return_tensors="pt",
        )
        target_enc = self.tokenizer(
            target_text,
            truncation=True,
            padding="max_length",
            max_length=self.max_length,
            return_tensors="pt",
        )
        return input_enc["input_ids"].squeeze(), input_enc["attention_mask"].squeeze(), target_enc["input_ids"].squeeze()

In [6]:
train_pairs = pairs
val_pairs = pairs[int(0.8 * len(pairs)):] 

max_length = 50
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

train_dataset = DialogDataset(train_pairs, tokenizer, max_length)
val_dataset = DialogDataset(val_pairs, tokenizer, max_length)

train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=128)



In [9]:
class ChatTransformerWithCNN(nn.Module):
    def __init__(self, pretrained_model_name="bert-base-uncased", num_heads=8, num_encoder_layers=4, num_decoder_layers=4):
        super(ChatTransformerWithCNN, self).__init__()

        # Pretrained BERT model for encoding
        self.encoder = BertModel.from_pretrained(pretrained_model_name)

        # CNN Layers
        self.conv1 = nn.Conv1d(in_channels=768, out_channels=512, kernel_size=3, padding=1)
        self.conv2 = nn.Conv1d(in_channels=512, out_channels=256, kernel_size=3, padding=1)
        self.conv3 = nn.Conv1d(in_channels=256, out_channels=768, kernel_size=3, padding=1)

        # Multihead Attention Layer
        self.attention = nn.MultiheadAttention(embed_dim=768, num_heads=num_heads, batch_first=True)

        # Decoder (use TransformerDecoderLayer to create TransformerDecoder)
        decoder_layers = nn.TransformerDecoderLayer(d_model=768, nhead=num_heads)
        self.decoder = nn.TransformerDecoder(decoder_layers, num_layers=num_decoder_layers)

        # Output Layer (for generating logits)
        self.fc_out = nn.Linear(768, self.encoder.config.vocab_size)

        # Tokenizer to get vocab size for the output layer
        self.tokenizer = BertTokenizer.from_pretrained(pretrained_model_name)

    def forward(self, input_ids, attention_mask, target_ids=None):
        # Get encoder outputs from BERT model
        encoder_outputs = self.encoder(input_ids, attention_mask=attention_mask)
        encoder_hidden_states = encoder_outputs['last_hidden_state']

        # Apply CNN layers to encoder hidden states
        cnn_out = self.conv1(encoder_hidden_states.permute(0, 2, 1))
        cnn_out = nn.ReLU()(cnn_out)
        cnn_out = self.conv2(cnn_out)
        cnn_out = nn.ReLU()(cnn_out)
        cnn_out = self.conv3(cnn_out).permute(0, 2, 1)

        # Create the padding mask for attention
        attn_mask = (attention_mask == 0).unsqueeze(1).unsqueeze(2)  # Shape [batch_size, 1, 1, seq_len]

        # Apply Multihead Attention on CNN output
        attn_output, _ = self.attention(cnn_out, cnn_out, cnn_out, key_padding_mask=attn_mask.squeeze())

        # Create decoder input (if target_ids is provided)
        if target_ids is not None:
            target_embeddings = self.encoder.embeddings(target_ids)  # Embed the target_ids if available
            decoder_output = self.decoder(tgt=target_embeddings, memory=attn_output)
        else:
            # During inference, use only the attention output
            decoder_output = self.decoder(tgt=attn_output, memory=attn_output)

        # Final output (logits for token prediction)
        logits = self.fc_out(decoder_output)

        return logits


In [10]:
model = ChatTransformerWithCNN()
model.to(device)

ChatTransformerWithCNN(
  (encoder): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSdpaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12,

In [11]:
# Step 3: Training and Validation
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(),lr=0.0001)

In [12]:
def load_model(model, optimizer, path="model_checkpoint.pth"):
    checkpoint = torch.load(path)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    start_epoch = checkpoint['epoch']
    return model, optimizer, start_epoch

In [13]:
load_model(model, optimizer, path="model_checkpoint.pth")

RuntimeError: Error(s) in loading state_dict for ChatTransformerWithCNN:
	Missing key(s) in state_dict: "decoder.layers.0.self_attn.in_proj_weight", "decoder.layers.0.self_attn.in_proj_bias", "decoder.layers.0.self_attn.out_proj.weight", "decoder.layers.0.self_attn.out_proj.bias", "decoder.layers.0.multihead_attn.in_proj_weight", "decoder.layers.0.multihead_attn.in_proj_bias", "decoder.layers.0.multihead_attn.out_proj.weight", "decoder.layers.0.multihead_attn.out_proj.bias", "decoder.layers.0.linear1.weight", "decoder.layers.0.linear1.bias", "decoder.layers.0.linear2.weight", "decoder.layers.0.linear2.bias", "decoder.layers.0.norm1.weight", "decoder.layers.0.norm1.bias", "decoder.layers.0.norm2.weight", "decoder.layers.0.norm2.bias", "decoder.layers.0.norm3.weight", "decoder.layers.0.norm3.bias", "decoder.layers.1.self_attn.in_proj_weight", "decoder.layers.1.self_attn.in_proj_bias", "decoder.layers.1.self_attn.out_proj.weight", "decoder.layers.1.self_attn.out_proj.bias", "decoder.layers.1.multihead_attn.in_proj_weight", "decoder.layers.1.multihead_attn.in_proj_bias", "decoder.layers.1.multihead_attn.out_proj.weight", "decoder.layers.1.multihead_attn.out_proj.bias", "decoder.layers.1.linear1.weight", "decoder.layers.1.linear1.bias", "decoder.layers.1.linear2.weight", "decoder.layers.1.linear2.bias", "decoder.layers.1.norm1.weight", "decoder.layers.1.norm1.bias", "decoder.layers.1.norm2.weight", "decoder.layers.1.norm2.bias", "decoder.layers.1.norm3.weight", "decoder.layers.1.norm3.bias", "decoder.layers.2.self_attn.in_proj_weight", "decoder.layers.2.self_attn.in_proj_bias", "decoder.layers.2.self_attn.out_proj.weight", "decoder.layers.2.self_attn.out_proj.bias", "decoder.layers.2.multihead_attn.in_proj_weight", "decoder.layers.2.multihead_attn.in_proj_bias", "decoder.layers.2.multihead_attn.out_proj.weight", "decoder.layers.2.multihead_attn.out_proj.bias", "decoder.layers.2.linear1.weight", "decoder.layers.2.linear1.bias", "decoder.layers.2.linear2.weight", "decoder.layers.2.linear2.bias", "decoder.layers.2.norm1.weight", "decoder.layers.2.norm1.bias", "decoder.layers.2.norm2.weight", "decoder.layers.2.norm2.bias", "decoder.layers.2.norm3.weight", "decoder.layers.2.norm3.bias", "decoder.layers.3.self_attn.in_proj_weight", "decoder.layers.3.self_attn.in_proj_bias", "decoder.layers.3.self_attn.out_proj.weight", "decoder.layers.3.self_attn.out_proj.bias", "decoder.layers.3.multihead_attn.in_proj_weight", "decoder.layers.3.multihead_attn.in_proj_bias", "decoder.layers.3.multihead_attn.out_proj.weight", "decoder.layers.3.multihead_attn.out_proj.bias", "decoder.layers.3.linear1.weight", "decoder.layers.3.linear1.bias", "decoder.layers.3.linear2.weight", "decoder.layers.3.linear2.bias", "decoder.layers.3.norm1.weight", "decoder.layers.3.norm1.bias", "decoder.layers.3.norm2.weight", "decoder.layers.3.norm2.bias", "decoder.layers.3.norm3.weight", "decoder.layers.3.norm3.bias". 
	Unexpected key(s) in state_dict: "decoder.encoder.layers.0.self_attn.in_proj_weight", "decoder.encoder.layers.0.self_attn.in_proj_bias", "decoder.encoder.layers.0.self_attn.out_proj.weight", "decoder.encoder.layers.0.self_attn.out_proj.bias", "decoder.encoder.layers.0.linear1.weight", "decoder.encoder.layers.0.linear1.bias", "decoder.encoder.layers.0.linear2.weight", "decoder.encoder.layers.0.linear2.bias", "decoder.encoder.layers.0.norm1.weight", "decoder.encoder.layers.0.norm1.bias", "decoder.encoder.layers.0.norm2.weight", "decoder.encoder.layers.0.norm2.bias", "decoder.encoder.layers.1.self_attn.in_proj_weight", "decoder.encoder.layers.1.self_attn.in_proj_bias", "decoder.encoder.layers.1.self_attn.out_proj.weight", "decoder.encoder.layers.1.self_attn.out_proj.bias", "decoder.encoder.layers.1.linear1.weight", "decoder.encoder.layers.1.linear1.bias", "decoder.encoder.layers.1.linear2.weight", "decoder.encoder.layers.1.linear2.bias", "decoder.encoder.layers.1.norm1.weight", "decoder.encoder.layers.1.norm1.bias", "decoder.encoder.layers.1.norm2.weight", "decoder.encoder.layers.1.norm2.bias", "decoder.encoder.layers.2.self_attn.in_proj_weight", "decoder.encoder.layers.2.self_attn.in_proj_bias", "decoder.encoder.layers.2.self_attn.out_proj.weight", "decoder.encoder.layers.2.self_attn.out_proj.bias", "decoder.encoder.layers.2.linear1.weight", "decoder.encoder.layers.2.linear1.bias", "decoder.encoder.layers.2.linear2.weight", "decoder.encoder.layers.2.linear2.bias", "decoder.encoder.layers.2.norm1.weight", "decoder.encoder.layers.2.norm1.bias", "decoder.encoder.layers.2.norm2.weight", "decoder.encoder.layers.2.norm2.bias", "decoder.encoder.layers.3.self_attn.in_proj_weight", "decoder.encoder.layers.3.self_attn.in_proj_bias", "decoder.encoder.layers.3.self_attn.out_proj.weight", "decoder.encoder.layers.3.self_attn.out_proj.bias", "decoder.encoder.layers.3.linear1.weight", "decoder.encoder.layers.3.linear1.bias", "decoder.encoder.layers.3.linear2.weight", "decoder.encoder.layers.3.linear2.bias", "decoder.encoder.layers.3.norm1.weight", "decoder.encoder.layers.3.norm1.bias", "decoder.encoder.layers.3.norm2.weight", "decoder.encoder.layers.3.norm2.bias", "decoder.encoder.norm.weight", "decoder.encoder.norm.bias", "decoder.decoder.layers.0.self_attn.in_proj_weight", "decoder.decoder.layers.0.self_attn.in_proj_bias", "decoder.decoder.layers.0.self_attn.out_proj.weight", "decoder.decoder.layers.0.self_attn.out_proj.bias", "decoder.decoder.layers.0.multihead_attn.in_proj_weight", "decoder.decoder.layers.0.multihead_attn.in_proj_bias", "decoder.decoder.layers.0.multihead_attn.out_proj.weight", "decoder.decoder.layers.0.multihead_attn.out_proj.bias", "decoder.decoder.layers.0.linear1.weight", "decoder.decoder.layers.0.linear1.bias", "decoder.decoder.layers.0.linear2.weight", "decoder.decoder.layers.0.linear2.bias", "decoder.decoder.layers.0.norm1.weight", "decoder.decoder.layers.0.norm1.bias", "decoder.decoder.layers.0.norm2.weight", "decoder.decoder.layers.0.norm2.bias", "decoder.decoder.layers.0.norm3.weight", "decoder.decoder.layers.0.norm3.bias", "decoder.decoder.layers.1.self_attn.in_proj_weight", "decoder.decoder.layers.1.self_attn.in_proj_bias", "decoder.decoder.layers.1.self_attn.out_proj.weight", "decoder.decoder.layers.1.self_attn.out_proj.bias", "decoder.decoder.layers.1.multihead_attn.in_proj_weight", "decoder.decoder.layers.1.multihead_attn.in_proj_bias", "decoder.decoder.layers.1.multihead_attn.out_proj.weight", "decoder.decoder.layers.1.multihead_attn.out_proj.bias", "decoder.decoder.layers.1.linear1.weight", "decoder.decoder.layers.1.linear1.bias", "decoder.decoder.layers.1.linear2.weight", "decoder.decoder.layers.1.linear2.bias", "decoder.decoder.layers.1.norm1.weight", "decoder.decoder.layers.1.norm1.bias", "decoder.decoder.layers.1.norm2.weight", "decoder.decoder.layers.1.norm2.bias", "decoder.decoder.layers.1.norm3.weight", "decoder.decoder.layers.1.norm3.bias", "decoder.decoder.layers.2.self_attn.in_proj_weight", "decoder.decoder.layers.2.self_attn.in_proj_bias", "decoder.decoder.layers.2.self_attn.out_proj.weight", "decoder.decoder.layers.2.self_attn.out_proj.bias", "decoder.decoder.layers.2.multihead_attn.in_proj_weight", "decoder.decoder.layers.2.multihead_attn.in_proj_bias", "decoder.decoder.layers.2.multihead_attn.out_proj.weight", "decoder.decoder.layers.2.multihead_attn.out_proj.bias", "decoder.decoder.layers.2.linear1.weight", "decoder.decoder.layers.2.linear1.bias", "decoder.decoder.layers.2.linear2.weight", "decoder.decoder.layers.2.linear2.bias", "decoder.decoder.layers.2.norm1.weight", "decoder.decoder.layers.2.norm1.bias", "decoder.decoder.layers.2.norm2.weight", "decoder.decoder.layers.2.norm2.bias", "decoder.decoder.layers.2.norm3.weight", "decoder.decoder.layers.2.norm3.bias", "decoder.decoder.layers.3.self_attn.in_proj_weight", "decoder.decoder.layers.3.self_attn.in_proj_bias", "decoder.decoder.layers.3.self_attn.out_proj.weight", "decoder.decoder.layers.3.self_attn.out_proj.bias", "decoder.decoder.layers.3.multihead_attn.in_proj_weight", "decoder.decoder.layers.3.multihead_attn.in_proj_bias", "decoder.decoder.layers.3.multihead_attn.out_proj.weight", "decoder.decoder.layers.3.multihead_attn.out_proj.bias", "decoder.decoder.layers.3.linear1.weight", "decoder.decoder.layers.3.linear1.bias", "decoder.decoder.layers.3.linear2.weight", "decoder.decoder.layers.3.linear2.bias", "decoder.decoder.layers.3.norm1.weight", "decoder.decoder.layers.3.norm1.bias", "decoder.decoder.layers.3.norm2.weight", "decoder.decoder.layers.3.norm2.bias", "decoder.decoder.layers.3.norm3.weight", "decoder.decoder.layers.3.norm3.bias", "decoder.decoder.norm.weight", "decoder.decoder.norm.bias". 

In [14]:
def train_fn(loader, model, optimizer, criterion):
    model.train()
    total_loss = 0
    for input_ids, attention_mask, target_ids in loader:
        input_ids, attention_mask, target_ids = input_ids.to(device), attention_mask.to(device), target_ids.to(device)
        optimizer.zero_grad()

        logits = model(input_ids, attention_mask)  # Pass only the arguments used
        logits = logits.view(-1, logits.size(-1))
        target_ids = target_ids.view(-1)

        loss = criterion(logits, target_ids)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
    return total_loss / len(loader)



In [15]:
def eval_fn(loader, model, criterion):
    model.eval()
    total_loss = 0
    with torch.no_grad():
        for input_ids, attention_mask, target_ids in loader:
            input_ids, attention_mask, target_ids = input_ids.to(device), attention_mask.to(device), target_ids.to(device)

            logits = model(input_ids, attention_mask)  # Pass only the arguments used
            logits = logits.view(-1, logits.size(-1))
            target_ids = target_ids.view(-1)

            loss = criterion(logits, target_ids)
            total_loss += loss.item()
    return total_loss / len(loader)

In [16]:
epoch=0
for epoc in range(500):
    epoch=epoc
    train_loss = train_fn(train_loader, model, optimizer, criterion)

    print(f"Epoch {epoch + 1}, Train Loss: {train_loss}")

  attn_output = torch.nn.functional.scaled_dot_product_attention(


Epoch 1, Train Loss: 2.006572997570038
Epoch 2, Train Loss: 1.6251471757888794
Epoch 3, Train Loss: 1.6268082102139791
Epoch 4, Train Loss: 1.589842394987742
Epoch 5, Train Loss: 1.5636261661847433


KeyboardInterrupt: 

In [17]:
val_loss = eval_fn(val_loader, model, criterion)
print(f" Val Loss: {val_loss}")

 Val Loss: 1.5610515872637432


In [18]:
def save_model(model, optimizer, epoch, path="model_checkpoint.pth"):
    torch.save({
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'epoch': epoch
    }, path)


In [19]:
save_model(model, optimizer, epoch, path="model_checkpoint.pth")

In [50]:
def predict_response(input_text, max_length=50):
    model.eval()  # Set model to evaluation mode
    with torch.no_grad():
        # Tokenize the input text
        input_enc = model.tokenizer(
            input_text,
            return_tensors="pt",
            truncation=True,
            padding="max_length",
            max_length=max_length
        )
        input_ids = input_enc["input_ids"].to(device)
        attention_mask = input_enc["attention_mask"].to(device)

        print("Input Token IDs:", input_ids)
        print("Attention Mask:", attention_mask)
        print("Input Text Tokens:", model.tokenizer.convert_ids_to_tokens(input_ids.squeeze().tolist()))
        
        # Get encoder outputs from the BERT encoder
        encoder_outputs = model.encoder(input_ids, attention_mask=attention_mask)
        encoder_hidden_states = encoder_outputs['last_hidden_state']  # Shape: [batch_size, seq_len, embed_dim]
        print("Encoder Hidden States Shape:", encoder_hidden_states.shape)

        # Apply CNN layers to encoder hidden states
        cnn_out = model.conv1(encoder_hidden_states.permute(0, 2, 1))  # Shape: [batch_size, embed_dim, seq_len]
        cnn_out = nn.ReLU()(cnn_out)
        cnn_out = model.conv2(cnn_out)
        cnn_out = nn.ReLU()(cnn_out)
        cnn_out = model.conv3(cnn_out).permute(0, 2, 1)  # Shape: [batch_size, seq_len, embed_dim]
        print("CNN Output Shape:", cnn_out.shape)

        # Create the padding mask for attention
        attn_mask = (attention_mask == 0)  # Shape [batch_size, seq_len]

        # Apply Multihead Attention
        attn_output, _ = model.attention(cnn_out, cnn_out, cnn_out, key_padding_mask=attn_mask)
        print("Attention Output Shape:", attn_output.shape)

        # Initialize the decoder input (start token)
        decoder_input = input_ids[:, 0:1]  # Take the first token
        decoder_input_embedded = model.encoder.embeddings(decoder_input)  # Get its embedding
        print("Decoder Input Shape:", decoder_input_embedded.shape)

        # Generate tokens one by one
        generated_ids = decoder_input
        for _ in range(max_length - 1):
            decoder_output = model.decoder(tgt=decoder_input_embedded, memory=attn_output)

            # Get the logits from the decoder output
            logits = model.fc_out(decoder_output)  # Shape: [batch_size, seq_len, vocab_size]
            print("Logits Shape:", logits.shape)

            # Get the most probable token (argmax over vocab dimension)
            next_token_ids = torch.argmax(logits[:, -1, :], dim=-1).unsqueeze(-1)

            # Embed the new token for the next step
            decoder_input_embedded = model.encoder.embeddings(next_token_ids)

            # Append the predicted token to the sequence
            generated_ids = torch.cat((generated_ids, next_token_ids), dim=1)

        # Decode the generated token IDs into text
        generated_text = model.tokenizer.decode(generated_ids.squeeze().tolist(), skip_special_tokens=True)

        print("Generated Token IDs:", generated_ids)
        print("Generated Text:", generated_text)

        return generated_text


In [51]:
predict_response("no problem. so how have you been?")

Input Token IDs: tensor([[ 101, 2053, 3291, 1012, 2061, 2129, 2031, 2017, 2042, 1029,  102,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0]], device='cuda:0')
Attention Mask: tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0]], device='cuda:0')
Input Text Tokens: ['[CLS]', 'no', 'problem', '.', 'so', 'how', 'have', 'you', 'been', '?', '[SEP]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', 

RuntimeError: shape '[1, 8, 96]' is invalid for input of size 38400