### Importing Libraries

In [20]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from transformers import BertTokenizer, BertModel


In [21]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

### Loading Dataser

In [22]:

def load_data(file_path):
    pairs = []
    with open(file_path, "r") as f:
        for line in f:
            if '\t' in line:
                input_text, target_text = line.strip().split("\t")
                pairs.append((input_text, target_text))
    return pairs


In [23]:
# Prepare Dataset and DataLoader
data_file = "dialogs.txt"
pairs = load_data(data_file)


In [24]:
# Step 1: Dataset Preparation
class DialogDataset(Dataset):
    def __init__(self, pairs, tokenizer, max_length):
        self.pairs = pairs
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        input_text, target_text = self.pairs[idx]
        input_enc = self.tokenizer(
            input_text,
            truncation=True,
            padding="max_length",
            max_length=self.max_length,
            return_tensors="pt",
        )
        target_enc = self.tokenizer(
            target_text,
            truncation=True,
            padding="max_length",
            max_length=self.max_length,
            return_tensors="pt",
        )
        return input_enc["input_ids"].squeeze(), input_enc["attention_mask"].squeeze(), target_enc["input_ids"].squeeze()

In [25]:
train_pairs = pairs
val_pairs = pairs[int(0.8 * len(pairs)):] 

max_length = 50
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

train_dataset = DialogDataset(train_pairs, tokenizer, max_length)
val_dataset = DialogDataset(val_pairs, tokenizer, max_length)

train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=128)



In [26]:

# Step 2: Transformer Model with Attention and CNN Layers
class ChatTransformerWithCNN(nn.Module):
    def __init__(self):
        super(ChatTransformerWithCNN, self).__init__()
        self.encoder = BertModel.from_pretrained("bert-base-uncased")
        
        # CNN Layers
        self.conv1 = nn.Conv1d(in_channels=768, out_channels=512, kernel_size=3, padding=1)
        self.conv2 = nn.Conv1d(in_channels=512, out_channels=256, kernel_size=3, padding=1)
        self.conv3 = nn.Conv1d(in_channels=256, out_channels=768, kernel_size=3, padding=1)

        # Attention Layer
        self.attention = nn.MultiheadAttention(embed_dim=768, num_heads=8, batch_first=True)
        
        # Decoder and Output
        self.decoder = nn.Transformer(
            d_model=768, num_encoder_layers=4, num_decoder_layers=4
        )
        self.fc_out = nn.Linear(768, tokenizer.vocab_size)

    def forward(self, input_ids, attention_mask, target_ids=None):
        encoder_outputs = self.encoder(input_ids, attention_mask=attention_mask)["last_hidden_state"]
    
    # Apply CNN layers
        cnn_out = self.conv1(encoder_outputs.permute(0, 2, 1))
        cnn_out = nn.ReLU()(cnn_out)
        cnn_out = self.conv2(cnn_out)
        cnn_out = nn.ReLU()(cnn_out)
        cnn_out = self.conv3(cnn_out).permute(0, 2, 1)

    # Apply attention layer
        attn_output, _ = self.attention(cnn_out, cnn_out, cnn_out, key_padding_mask=~attention_mask.bool())

    # Decoder
        decoder_outputs = self.decoder(
            attn_output.permute(1, 0, 2), attn_output.permute(1, 0, 2)
        )
        logits = self.fc_out(decoder_outputs)
        return logits


In [27]:
model = ChatTransformerWithCNN()
model.to(device)

ChatTransformerWithCNN(
  (encoder): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSdpaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12,

In [28]:
# Step 3: Training and Validation
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(),lr=0.0001)

In [29]:
def load_model(model, optimizer, path="model_checkpoint.pth"):
    checkpoint = torch.load(path)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    start_epoch = checkpoint['epoch']
    return model, optimizer, start_epoch

In [30]:
load_model(model, optimizer, path="model_checkpoint.pth")

(ChatTransformerWithCNN(
   (encoder): BertModel(
     (embeddings): BertEmbeddings(
       (word_embeddings): Embedding(30522, 768, padding_idx=0)
       (position_embeddings): Embedding(512, 768)
       (token_type_embeddings): Embedding(2, 768)
       (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
       (dropout): Dropout(p=0.1, inplace=False)
     )
     (encoder): BertEncoder(
       (layer): ModuleList(
         (0-11): 12 x BertLayer(
           (attention): BertAttention(
             (self): BertSdpaSelfAttention(
               (query): Linear(in_features=768, out_features=768, bias=True)
               (key): Linear(in_features=768, out_features=768, bias=True)
               (value): Linear(in_features=768, out_features=768, bias=True)
               (dropout): Dropout(p=0.1, inplace=False)
             )
             (output): BertSelfOutput(
               (dense): Linear(in_features=768, out_features=768, bias=True)
               (LayerNorm): LayerN

In [31]:
def train_fn(loader, model, optimizer, criterion):
    model.train()
    total_loss = 0
    for input_ids, attention_mask, target_ids in loader:
        input_ids, attention_mask, target_ids = input_ids.to(device), attention_mask.to(device), target_ids.to(device)
        optimizer.zero_grad()

        logits = model(input_ids, attention_mask)  # Pass only the arguments used
        logits = logits.view(-1, logits.size(-1))
        target_ids = target_ids.view(-1)

        loss = criterion(logits, target_ids)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
    return total_loss / len(loader)



In [32]:
def eval_fn(loader, model, criterion):
    model.eval()
    total_loss = 0
    with torch.no_grad():
        for input_ids, attention_mask, target_ids in loader:
            input_ids, attention_mask, target_ids = input_ids.to(device), attention_mask.to(device), target_ids.to(device)

            logits = model(input_ids, attention_mask)  # Pass only the arguments used
            logits = logits.view(-1, logits.size(-1))
            target_ids = target_ids.view(-1)

            loss = criterion(logits, target_ids)
            total_loss += loss.item()
    return total_loss / len(loader)

In [33]:
epoch=0
for epoc in range(500):
    epoch=epoc
    train_loss = train_fn(train_loader, model, optimizer, criterion)

    print(f"Epoch {epoch + 1}, Train Loss: {train_loss}")

Epoch 1, Train Loss: 1.6279468615849813


KeyboardInterrupt: 

In [34]:
val_loss = eval_fn(val_loader, model, criterion)
print(f" Val Loss: {val_loss}")

 Val Loss: 1.7199497818946838


In [35]:
def save_model(model, optimizer, epoch, path="model_checkpoint.pth"):
    torch.save({
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'epoch': epoch
    }, path)


In [36]:
save_model(model, optimizer, epoch, path="model_checkpoint.pth")

In [41]:
def predict_response(input_text):
    model.eval()
    with torch.no_grad():
        input_enc = tokenizer(
            input_text,
            return_tensors="pt",
            truncation=True,
            padding="max_length",
            max_length=max_length
        )
        input_ids = input_enc["input_ids"].to(device)
        attention_mask = input_enc["attention_mask"].to(device)
        
        print("Input Token IDs:", input_ids)
        print("Attention Mask:", attention_mask)
        print("Input Text Tokens:", tokenizer.convert_ids_to_tokens(input_ids.squeeze().tolist()))
        outputs = model(input_ids, attention_mask, input_ids)  # Using input_ids as target_ids for inference
        logits = outputs.view(-1, outputs.size(-1))
        print("Logits Shape:", logits.shape)
        print("Logits (Sample):", logits[:5])

        predicted_ids = torch.argmax(logits, dim=-1)
      
        print("Predicted Token IDs:", predicted_ids)

        return tokenizer.decode(predicted_ids, skip_special_tokens=True)

In [42]:
predict_response("no problem. so how have you been?")

Input Token IDs: tensor([[ 101, 2053, 3291, 1012, 2061, 2129, 2031, 2017, 2042, 1029,  102,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0]], device='cuda:0')
Attention Mask: tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0]], device='cuda:0')
Input Text Tokens: ['[CLS]', 'no', 'problem', '.', 'so', 'how', 'have', 'you', 'been', '?', '[SEP]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', 

''