In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizer, BertModel
import math
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence


  from .autonotebook import tqdm as notebook_tqdm


## RAW DATA

In [2]:
legal_conversations = [
    [
        "A: Hello, I’m Sarah from XYZ Bank’s support team. How can I help you today?",
        "B: Hi Sarah, I was looking to update my address on my account. Could you assist me with that?",
        "A: Certainly. Could you please verify the last four digits of your account number?",
        "B: The last four digits are 1234.",
        "A: Great. I see your current address is 123 Elm Street. What would you like to update it to?",
        "B: I’d like it changed to 456 Oak Avenue, Springfield.",
        "A: Perfect. I’ve updated the address. Is there anything else I can help with?",
        "B: No, that’s all. Thank you so much.",
        "A: You’re welcome! Have a wonderful day."
    ],
    [
        "A: Good morning, this is Max from ABC Internet Services. How may I assist you?",
        "B: Hi Max, my internet has been running slower than usual. Can you help me figure out why?",
        "A: Sure, let’s run a quick diagnostic. Could you confirm the email address associated with your account?",
        "B: It’s jane.doe@example.com.",
        "A: Thanks. I see there’s some scheduled maintenance in your area which might cause slow speeds. It should be resolved by tomorrow morning.",
        "B: Got it, thanks for checking. Is there any way to get a temporary speed boost?",
        "A: Unfortunately, not during maintenance. But I can offer you a small credit for the inconvenience. Would that help?",
        "B: That would be great. Thanks!",
        "A: I’ve applied a $5 credit. Anything else I can do?",
        "B: No, that’s all. Appreciate your help.",
        "A: My pleasure. Have a nice day!"
    ],
    [
        "A: Hello, Julie from Secure Payments. How can I help?",
        "B: Hi Julie, I want to set a travel notice on my credit card.",
        "A: Absolutely. Could I have the last transaction amount you made so I can verify your identity?",
        "B: My last transaction was $45 at GroceryMart.",
        "A: Perfect, I see that. What dates and countries will you be traveling to?",
        "B: I’ll be in Germany from June 10th to June 20th.",
        "A: Got it. I’ve placed a travel notice for those dates. You’re all set.",
        "B: Thank you, that’s all I needed.",
        "A: You’re welcome. Safe travels!"
    ]
]

In [3]:
vishing_conversations = [
    [
        "A: Hello, this is Andrew calling from Premium Bank’s fraud department.",
        "B: Oh, hi. Is there an issue with my account?",
        "A: Yes, we noticed several suspicious charges. Could you provide your full account number so we can secure your account immediately?",
        "B: I’m not comfortable giving my full account number over the phone.",
        "A: It’s urgent! Your account is at risk right now. If you don’t provide the account and your PIN, we can’t protect your money.",
        "B: I should call the official bank number before giving this information.",
        "A: There’s no time. Just give me your PIN, we’ll reverse the charges right now.",
        "B: I’ll hang up and check with the bank directly. Goodbye.",
        "A: Wait, no, don’t disconnect—!"
    ],
    [
        "A: Hi, I’m calling from the government tax office. We have an urgent notice for you.",
        "B: The tax office? Is there a problem?",
        "A: Yes, there is a warrant for your arrest due to unpaid taxes. To fix this, you need to pay immediately.",
        "B: That sounds suspicious. I don’t think the tax office calls like this.",
        "A: If you don’t give me your credit card number right now, the police will be at your door in an hour.",
        "B: No, I’m going to hang up and verify this through official channels.",
        "A: Don’t you dare hang up! You must pay now!",
        "B: (Hangs up)"
    ],
    [
        "A: Good afternoon, this is Alex from Techy Support for your mobile service.",
        "B: Hi, what’s the issue?",
        "A: Your phone has been compromised. To fix it, I need your password and PIN so I can access your device remotely.",
        "B: That’s not normal procedure.",
        "A: It’s an emergency! Hackers are stealing your data. Give me your PIN so I can lock them out.",
        "B: I’m going to call the official support line and verify.",
        "A: No time! They’ll steal everything right now if you don’t comply!",
        "B: I don’t believe you. Goodbye.",
        "A: Wait, I…!"
    ]
]

In [4]:
all_conversations = legal_conversations + vishing_conversations


## Get pre trained vocabulary

In [5]:
# !pip install transformers



In [6]:
from transformers import GPT2Tokenizer, GPT2LMHeadModel

In [7]:
# Initialize the tokenizer
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# Initialize the pre-trained BERT model to extract embeddings
bert_model = BertModel.from_pretrained('bert-base-uncased')

# Freeze BERT parameters if you don't want to fine-tune them
for param in bert_model.parameters():
    param.requires_grad = False


In [12]:
import torch

# Check if CUDA (GPU) is available
if torch.cuda.is_available():
    device = torch.device('cuda')
    print("GPU is available")
else:
    device = torch.device('cpu')
    print("Only CPU is available")

Only CPU is available


In [9]:
# model = model.to(device)

NameError: name 'model' is not defined

In [13]:
class ConversationDataset(Dataset):
    def __init__(self, conversations, tokenizer, max_length=512):
        self.inputs = []
        self.targets = []
        self.attention_masks = []
        
        for conv in conversations:
            # Flatten the conversation into a single string
            conv_text = " ".join(conv)
            
            # Tokenize and encode
            encoding = tokenizer.encode_plus(
                conv_text,
                add_special_tokens=True,
                max_length=max_length,
                truncation=True,
                padding='max_length',
                return_tensors='pt'
            )
            
            input_ids = encoding['input_ids'].squeeze(0)  # Shape: [max_length]
            attention_mask = encoding['attention_mask'].squeeze(0)  # Shape: [max_length]
            
            # For language modeling, target is the input shifted by one
            # So, input is tokens 0 to n-2, target is tokens 1 to n-1
            self.inputs.append(input_ids[:-1])
            self.targets.append(input_ids[1:])
            self.attention_masks.append(attention_mask[:-1])
    
    def __len__(self):
        return len(self.inputs)
    
    def __getitem__(self, idx):
        return self.inputs[idx], self.targets[idx], self.attention_masks[idx]


In [14]:
def collate_fn(batch):
    inp_batch = [item[0] for item in batch]
    tgt_batch = [item[1] for item in batch]
    mask_batch = [item[2] for item in batch]

    inp_padded = torch.stack(inp_batch)
    tgt_padded = torch.stack(tgt_batch)
    mask_padded = torch.stack(mask_batch).unsqueeze(1).unsqueeze(2)  # [B, 1, 1, S]
    
    return inp_padded, tgt_padded, mask_padded


In [15]:
# Create the dataset
dataset = ConversationDataset(all_conversations, tokenizer, max_length=128)  # Adjust max_length as needed

# Create the DataLoader
dataloader = DataLoader(dataset, batch_size=2, shuffle=True, collate_fn=collate_fn)


## Transformer Model

In [16]:
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=512):
        super(PositionalEncoding, self).__init__()
        
        # Create constant 'pe' matrix with values dependent on
        # pos and i
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        
        pe = pe.unsqueeze(0)  # Shape: [1, max_len, d_model]
        self.register_buffer('pe', pe)
    
    def forward(self, x):
        # x: [batch_size, seq_len, d_model]
        x = x + self.pe[:, :x.size(1), :]
        return x

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super(MultiHeadAttention, self).__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        
        self.d_k = d_model // num_heads
        self.num_heads = num_heads
        
        self.q_linear = nn.Linear(d_model, d_model)
        self.k_linear = nn.Linear(d_model, d_model)
        self.v_linear = nn.Linear(d_model, d_model)
        
        self.out = nn.Linear(d_model, d_model)
        
        self.attention = None  # To store attention weights
    
    def forward(self, q, k, v, mask=None):
        batch_size = q.size(0)
        
        # Perform linear operation and split into h heads
        q = self.q_linear(q).view(batch_size, -1, self.num_heads, self.d_k).transpose(1,2)  # [batch_size, num_heads, seq_len, d_k]
        k = self.k_linear(k).view(batch_size, -1, self.num_heads, self.d_k).transpose(1,2)
        v = self.v_linear(v).view(batch_size, -1, self.num_heads, self.d_k).transpose(1,2)
        
        # Apply scaled dot-product attention
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)  # [batch_size, num_heads, seq_len, seq_len]
        
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)
        
        attn = torch.softmax(scores, dim=-1)  # [batch_size, num_heads, seq_len, seq_len]
        self.attention = attn
        
        out = torch.matmul(attn, v)  # [batch_size, num_heads, seq_len, d_k]
        
        # Concatenate heads
        out = out.transpose(1,2).contiguous().view(batch_size, -1, self.num_heads * self.d_k)  # [batch_size, seq_len, d_model]
        
        # Final linear layer
        out = self.out(out)  # [batch_size, seq_len, d_model]
        
        return out

class FeedForward(nn.Module):
    def __init__(self, d_model, d_ff):
        super(FeedForward, self).__init__()
        self.linear1 = nn.Linear(d_model, d_ff)
        self.relu = nn.ReLU()
        self.linear2 = nn.Linear(d_ff, d_model)
    
    def forward(self, x):
        return self.linear2(self.relu(self.linear1(x)))

class EncoderLayer(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super(EncoderLayer, self).__init__()
        self.mha = MultiHeadAttention(d_model, num_heads)
        self.ff = FeedForward(d_model, d_ff)
        
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
    
    def forward(self, x, mask=None):
        # Multi-head attention
        attn_output = self.mha(x, x, x, mask)  # Self-attention
        x = self.norm1(x + self.dropout1(attn_output))
        
        # Feedforward
        ff_output = self.ff(x)
        x = self.norm2(x + self.dropout2(ff_output))
        
        return x

class TransformerModel(nn.Module):
    def __init__(self, d_model, num_heads, num_layers, d_ff, vocab_size, max_len=512, dropout=0.1):
        super(TransformerModel, self).__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.positional_encoding = PositionalEncoding(d_model, max_len)
        self.encoder_layers = nn.ModuleList([
            EncoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)
        ])
        self.fc_out = nn.Linear(d_model, vocab_size)
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, x, mask=None):
        x = self.embedding(x)  # [batch_size, seq_len, d_model]
        x = self.positional_encoding(x)
        x = self.dropout(x)
        
        for layer in self.encoder_layers:
            x = layer(x, mask)
        
        out = self.fc_out(x)  # [batch_size, seq_len, vocab_size]
        return out


## Initialize the Transformer Model with Pre-trained Embeddings

In [17]:
# Extract BERT's embedding weights
pretrained_embeddings = bert_model.embeddings.word_embeddings.weight.data  # [vocab_size, hidden_size]

# Define model parameters
d_model = pretrained_embeddings.size(1)  # Typically 768 for BERT-base
num_heads = 12  # For BERT-base
num_layers = 6  # Number of Transformer encoder layers
d_ff = 3072  # Typically 4 * d_model
dropout = 0.1
vocab_size = tokenizer.vocab_size  # Size of BERT's tokenizer vocabulary
max_len = 128  # Adjust based on your dataset

# Initialize the Transformer model
model = TransformerModel(d_model, num_heads, num_layers, d_ff, vocab_size, max_len, dropout)

# Initialize embedding layer with pre-trained embeddings
model.embedding.weight.data.copy_(pretrained_embeddings)

# Optionally, freeze embedding layer to prevent updating during training
model.embedding.weight.requires_grad = False


## Define Training Components

In [18]:
# Move model to device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

# Define loss function and optimizer
criterion = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)
optimizer = optim.Adam(model.parameters(), lr=1e-4)


In [19]:
num_epochs = 5
model.train()

for epoch in range(num_epochs):
    total_loss = 0
    for batch in dataloader:
        inp, tgt, mask = batch
        inp = inp.to(device)
        tgt = tgt.to(device)
        mask = mask.to(device)
        
        optimizer.zero_grad()
        outputs = model(inp, mask)  # [batch_size, seq_len, vocab_size]
        
        # Reshape outputs and targets for loss computation
        outputs = outputs.view(-1, vocab_size)  # [batch_size * seq_len, vocab_size]
        tgt = tgt.view(-1)  # [batch_size * seq_len]
        
        loss = criterion(outputs, tgt)
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()
    
    avg_loss = total_loss / len(dataloader)
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {avg_loss:.4f}")

print("Training completed.")


Epoch 1/5, Loss: 9.7168
Epoch 2/5, Loss: 8.4117
Epoch 3/5, Loss: 8.0370
Epoch 4/5, Loss: 7.7322
Epoch 5/5, Loss: 7.4145
Training completed.


In [20]:
def generate_continuation(prefix_lines, tokenizer, model, max_length=50):
    model.eval()
    with torch.no_grad():
        # Flatten the prefix lines into a single string
        prefix_text = " ".join(prefix_lines)
        
        # Tokenize and encode
        encoding = tokenizer.encode_plus(
            prefix_text,
            add_special_tokens=True,
            max_length=128,
            truncation=True,
            return_tensors='pt'
        )
        
        input_ids = encoding['input_ids'].to(device)  # [1, S]
        attention_mask = encoding['attention_mask'].to(device)  # [1, S]
        
        # Initialize generated sequence with input_ids
        generated = input_ids
        
        for _ in range(max_length):
            # Generate mask
            mask = (generated != tokenizer.pad_token_id).unsqueeze(1).unsqueeze(2)  # [1, 1, 1, S]
            
            # Get the model's predictions
            outputs = model(generated, mask)  # [1, S, vocab_size]
            logits = outputs[:, -1, :]  # [1, vocab_size]
            probs = F.softmax(logits, dim=-1)
            
            # Greedy decoding: select the token with highest probability
            next_token = torch.argmax(probs, dim=-1).unsqueeze(0)  # [1,1]
            
            # Append the predicted token to the generated sequence
            generated = torch.cat((generated, next_token), dim=1)  # [1, S+1]
            
            # If EOS token is generated, stop
            if next_token.item() == tokenizer.sep_token_id or next_token.item() == tokenizer.eos_token_id:
                break
        
        # Decode the generated tokens to text
        generated_text = tokenizer.decode(generated.squeeze(), skip_special_tokens=True)
        
        return generated_text

In [22]:
# Example usage:
partial_conversation = [
    "A: Hello.",
    "B: Hello, yes"
]

continuation = generate_continuation(partial_conversation, tokenizer, model, max_length=50)

print("\nGiven prefix:")
for line in partial_conversation:
    print(line)

print("\nPredicted continuation:")
print(continuation)


Given prefix:
A: Hello.
B: Hello, yes

Predicted continuation:
a : hello. b : hello, yes..................................................
