In [1]:
import math
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence

In [2]:

###############################################################
# Conversations (Legitimate and Vishing) in List Format
###############################################################

In [3]:
legal_conversations = [
    [
        "A: Hello, I’m Sarah from XYZ Bank’s support team. How can I help you today?",
        "B: Hi Sarah, I was looking to update my address on my account. Could you assist me with that?",
        "A: Certainly. Could you please verify the last four digits of your account number?",
        "B: The last four digits are 1234.",
        "A: Great. I see your current address is 123 Elm Street. What would you like to update it to?",
        "B: I’d like it changed to 456 Oak Avenue, Springfield.",
        "A: Perfect. I’ve updated the address. Is there anything else I can help with?",
        "B: No, that’s all. Thank you so much.",
        "A: You’re welcome! Have a wonderful day."
    ],
    [
        "A: Good morning, this is Max from ABC Internet Services. How may I assist you?",
        "B: Hi Max, my internet has been running slower than usual. Can you help me figure out why?",
        "A: Sure, let’s run a quick diagnostic. Could you confirm the email address associated with your account?",
        "B: It’s jane.doe@example.com.",
        "A: Thanks. I see there’s some scheduled maintenance in your area which might cause slow speeds. It should be resolved by tomorrow morning.",
        "B: Got it, thanks for checking. Is there any way to get a temporary speed boost?",
        "A: Unfortunately, not during maintenance. But I can offer you a small credit for the inconvenience. Would that help?",
        "B: That would be great. Thanks!",
        "A: I’ve applied a $5 credit. Anything else I can do?",
        "B: No, that’s all. Appreciate your help.",
        "A: My pleasure. Have a nice day!"
    ],
    [
        "A: Hello, Julie from Secure Payments. How can I help?",
        "B: Hi Julie, I want to set a travel notice on my credit card.",
        "A: Absolutely. Could I have the last transaction amount you made so I can verify your identity?",
        "B: My last transaction was $45 at GroceryMart.",
        "A: Perfect, I see that. What dates and countries will you be traveling to?",
        "B: I’ll be in Germany from June 10th to June 20th.",
        "A: Got it. I’ve placed a travel notice for those dates. You’re all set.",
        "B: Thank you, that’s all I needed.",
        "A: You’re welcome. Safe travels!"
    ]
]

In [4]:
vishing_conversations = [
    [
        "A: Hello, this is Andrew calling from Premium Bank’s fraud department.",
        "B: Oh, hi. Is there an issue with my account?",
        "A: Yes, we noticed several suspicious charges. Could you provide your full account number so we can secure your account immediately?",
        "B: I’m not comfortable giving my full account number over the phone.",
        "A: It’s urgent! Your account is at risk right now. If you don’t provide the account and your PIN, we can’t protect your money.",
        "B: I should call the official bank number before giving this information.",
        "A: There’s no time. Just give me your PIN, we’ll reverse the charges right now.",
        "B: I’ll hang up and check with the bank directly. Goodbye.",
        "A: Wait, no, don’t disconnect—!"
    ],
    [
        "A: Hi, I’m calling from the government tax office. We have an urgent notice for you.",
        "B: The tax office? Is there a problem?",
        "A: Yes, there is a warrant for your arrest due to unpaid taxes. To fix this, you need to pay immediately.",
        "B: That sounds suspicious. I don’t think the tax office calls like this.",
        "A: If you don’t give me your credit card number right now, the police will be at your door in an hour.",
        "B: No, I’m going to hang up and verify this through official channels.",
        "A: Don’t you dare hang up! You must pay now!",
        "B: (Hangs up)"
    ],
    [
        "A: Good afternoon, this is Alex from Techy Support for your mobile service.",
        "B: Hi, what’s the issue?",
        "A: Your phone has been compromised. To fix it, I need your password and PIN so I can access your device remotely.",
        "B: That’s not normal procedure.",
        "A: It’s an emergency! Hackers are stealing your data. Give me your PIN so I can lock them out.",
        "B: I’m going to call the official support line and verify.",
        "A: No time! They’ll steal everything right now if you don’t comply!",
        "B: I don’t believe you. Goodbye.",
        "A: Wait, I…!"
    ]
]

In [5]:
all_conversations = legal_conversations + vishing_conversations

In [6]:
###############################################################
# 1. Preprocessing and Dataset
###############################################################

In [7]:
# Basic tokenization (whitespace)
def tokenize_line(line):
    return line.strip().split()

# Special tokens
PAD_TOKEN = "<pad>"
UNK_TOKEN = "<unk>"
BOS_TOKEN = "<bos>"
EOS_TOKEN = "<eos>"

In [8]:
# Build vocab
all_tokens = []
for conv in all_conversations:
    for line in conv:
        all_tokens.extend(tokenize_line(line))

vocab = {PAD_TOKEN:0, UNK_TOKEN:1, BOS_TOKEN:2, EOS_TOKEN:3}
for tok in all_tokens:
    if tok not in vocab:
        vocab[tok] = len(vocab)

inv_vocab = {v:k for k,v in vocab.items()}
vocab_size = len(vocab)


In [9]:
def encode_tokens(tokens, vocab):
    return [vocab.get(t, vocab[UNK_TOKEN]) for t in tokens]


In [10]:
class ConversationDataset(Dataset):
    def __init__(self, conversations, vocab):
        self.data = []
        for conv in conversations:
            # Flatten the conversation into a single sequence
            # Add BOS at start and EOS at end
            seq_tokens = [BOS_TOKEN]
            for line in conv:
                seq_tokens.extend(tokenize_line(line))
            seq_tokens.append(EOS_TOKEN)
            encoded = torch.tensor(encode_tokens(seq_tokens, vocab), dtype=torch.long)
            self.data.append(encoded)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # For language modeling:
        # Input: seq[:-1], Target: seq[1:]
        seq = self.data[idx]
        inp = seq[:-1]
        tgt = seq[1:]
        return inp, tgt

In [11]:
def collate_fn(batch):
    inp_batch = [b[0] for b in batch]
    tgt_batch = [b[1] for b in batch]

    inp_padded = pad_sequence(inp_batch, batch_first=True, padding_value=vocab[PAD_TOKEN])
    tgt_padded = pad_sequence(tgt_batch, batch_first=True, padding_value=vocab[PAD_TOKEN])

    src_mask = (inp_padded != vocab[PAD_TOKEN]).unsqueeze(1).unsqueeze(2)
    return inp_padded, tgt_padded, src_mask

In [12]:

dataset = ConversationDataset(all_conversations, vocab)
dataloader = DataLoader(dataset, batch_size=2, shuffle=True, collate_fn=collate_fn)


In [13]:
###############################################################
# 2. Transformer Language Model Definition
###############################################################

In [14]:
class Embedding(nn.Module):
    def __init__(self, vocab_size, d_model):
        super(Embedding, self).__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.d_model = d_model

    def forward(self, x):
        return self.embedding(x) * math.sqrt(self.d_model)


In [15]:
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.encoding = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))

        self.encoding[:, 0::2] = torch.sin(position * div_term)
        self.encoding[:, 1::2] = torch.cos(position * div_term)
        self.encoding = self.encoding.unsqueeze(0)

    def forward(self, x):
        return x + self.encoding[:, :x.size(1), :].to(x.device)


In [16]:
def scaled_dot_product_attention(query, key, value, mask=None):
    d_k = query.size(-1)
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, float('-inf'))
    attention_weights = torch.softmax(scores, dim=-1)
    return torch.matmul(attention_weights, value), attention_weights

In [17]:
class MultiHeadAttention(nn.Module):
    def __init__(self, num_heads, d_model):
        super(MultiHeadAttention, self).__init__()
        assert d_model % num_heads == 0
        self.d_k = d_model // num_heads
        self.num_heads = num_heads
        self.linear_q = nn.Linear(d_model, d_model)
        self.linear_k = nn.Linear(d_model, d_model)
        self.linear_v = nn.Linear(d_model, d_model)
        self.fc_out = nn.Linear(d_model, d_model)

    def forward(self, query, key, value, mask=None):
        batch_size = query.size(0)
        def transform(x, linear_layer):
            x = linear_layer(x)
            return x.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        query = transform(query, self.linear_q)
        key = transform(key, self.linear_k)
        value = transform(value, self.linear_v)

        attention_output, _ = scaled_dot_product_attention(query, key, value, mask)
        attention_output = attention_output.transpose(1, 2).contiguous().view(batch_size, -1, self.num_heads * self.d_k)
        return self.fc_out(attention_output)

In [18]:
class PositionwiseFeedforward(nn.Module):
    def __init__(self, d_model, d_ff, dropout=0.1):
        super(PositionwiseFeedforward, self).__init__()
        self.linear1 = nn.Linear(d_model, d_ff)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(d_ff, d_model)

    def forward(self, x):
        return self.linear2(self.dropout(F.relu(self.linear1(x))))

In [19]:
class EncoderLayer(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super(EncoderLayer, self).__init__()
        self.attention = MultiHeadAttention(num_heads, d_model)
        self.feed_forward = PositionwiseFeedforward(d_model, d_ff, dropout)
        self.layer_norm1 = nn.LayerNorm(d_model)
        self.layer_norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, src, src_mask):
        src2 = self.attention(src, src, src, src_mask)
        src = self.layer_norm1(src + self.dropout(src2))
        
        src2 = self.feed_forward(src)
        src = self.layer_norm2(src + self.dropout(src2))
        return src

In [20]:
class TransformerLanguageModel(nn.Module):
    def __init__(self, vocab_size, d_model=128, num_layers=2, num_heads=4, d_ff=512, dropout=0.1):
        super(TransformerLanguageModel, self).__init__()
        self.embedding = Embedding(vocab_size, d_model)
        self.positional_encoding = PositionalEncoding(d_model)
        self.layers = nn.ModuleList([EncoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)])
        self.fc_out = nn.Linear(d_model, vocab_size)

    def forward(self, src, src_mask):
        x = self.embedding(src)
        x = self.positional_encoding(x)
        for layer in self.layers:
            x = layer(x, src_mask)
        logits = self.fc_out(x)
        return logits


In [21]:
###############################################################
# 3. Training
###############################################################

In [22]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = TransformerLanguageModel(vocab_size).to(device)
criterion = nn.CrossEntropyLoss(ignore_index=vocab[PAD_TOKEN])
optimizer = optim.Adam(model.parameters(), lr=0.001)

  from .autonotebook import tqdm as notebook_tqdm


In [23]:
epochs = 5
model.train()
for epoch in range(epochs):
    total_loss = 0
    for inp_padded, tgt_padded, src_mask in dataloader:
        inp_padded = inp_padded.to(device)
        tgt_padded = tgt_padded.to(device)
        src_mask = src_mask.to(device)

        optimizer.zero_grad()
        logits = model(inp_padded, src_mask)
        loss = criterion(logits.view(-1, vocab_size), tgt_padded.view(-1))
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    avg_loss = total_loss / len(dataloader)
    print(f"Epoch {epoch+1}, Loss: {avg_loss:.4f}")

print("Training completed.")

Epoch 1, Loss: 5.9809
Epoch 2, Loss: 5.4180
Epoch 3, Loss: 5.0699
Epoch 4, Loss: 4.7926
Epoch 5, Loss: 4.4776
Training completed.


In [24]:
###############################################################
# 4. Inference
#
# Given a partial conversation, predict the rest of the conversation.
###############################################################


In [25]:
def generate_continuation(prefix_lines, model, max_length=50):
    model.eval()
    with torch.no_grad():
        # Encode prefix
        prefix_tokens = [BOS_TOKEN]
        for line in prefix_lines:
            prefix_tokens.extend(tokenize_line(line))
        
        inp = torch.tensor([encode_tokens(prefix_tokens, vocab)], dtype=torch.long).to(device)
        for _ in range(max_length):
            src_mask = (inp != vocab[PAD_TOKEN]).unsqueeze(1).unsqueeze(2)
            logits = model(inp, src_mask)
            next_token_logits = logits[:, -1, :]
            next_token = torch.argmax(next_token_logits, dim=-1).item()
            if next_token == vocab[EOS_TOKEN]:
                break
            inp = torch.cat([inp, torch.tensor([[next_token]]).to(device)], dim=1)

        generated_ids = inp.squeeze(0).tolist()
        # Remove BOS
        generated_ids = generated_ids[1:]
        return [inv_vocab[i] for i in generated_ids]


In [26]:
# Example: partial conversation
partial_conversation = [
    "A: Hello, I am John from your bank.",
    "B: Hello, how can I help you?"
]

predicted_tokens = generate_continuation(partial_conversation, model)
print("\nGiven prefix:")
for line in partial_conversation:
    print(line)

print("\nPredicted continuation:")
print(" ".join(predicted_tokens))


Given prefix:
A: Hello, I am John from your bank.
B: Hello, how can I help you?

Predicted continuation:
A: Hello, I <unk> <unk> from your <unk> B: Hello, <unk> can I help you? B: Hi your mobile your mobile your mobile your mobile I can I can I can I can I can I can I can I can I can I can I can I can I can I can I can I can I can I can I can I can
