In [1]:
import math
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

In [2]:
########################################
# 1. Sample Conversation Data
########################################

In [3]:
legal_conversations = [
    [
        "A: Hello, I’m Sarah from XYZ Bank’s support team. How can I help you today?",
        "B: Hi Sarah, I was looking to update my address on my account. Could you assist me with that?",
        "A: Certainly. Could you please verify the last four digits of your account number?",
        "B: The last four digits are 1234.",
        "A: Great. I see your current address is 123 Elm Street. What would you like to update it to?",
        "B: I’d like it changed to 456 Oak Avenue, Springfield.",
        "A: Perfect. I’ve updated the address. Is there anything else I can help with?",
        "B: No, that’s all. Thank you so much.",
        "A: You’re welcome! Have a wonderful day."
    ],
    [
        "A: Good morning, this is Max from ABC Internet Services. How may I assist you?",
        "B: Hi Max, my internet has been running slower than usual. Can you help me figure out why?",
        "A: Sure, let’s run a quick diagnostic. Could you confirm the email address associated with your account?",
        "B: It’s jane.doe@example.com.",
        "A: Thanks. I see there’s some scheduled maintenance in your area which might cause slow speeds.",
        "B: Got it, thanks for checking. Is there any way to get a temporary speed boost?",
        "A: Unfortunately, not during maintenance. But I can offer you a small credit for the inconvenience.",
        "B: That would be great. Thanks!",
        "A: I’ve applied a $5 credit. Anything else I can do?",
        "B: No, that’s all. Appreciate your help.",
        "A: My pleasure. Have a nice day!"
    ]
]

vishing_conversations = [
    [
        "A: Hello, this is Andrew calling from Premium Bank’s fraud department.",
        "B: Oh, hi. Is there an issue with my account?",
        "A: Yes, we noticed several suspicious charges. Could you provide your full account number?",
        "B: I’m not comfortable giving my full account number over the phone.",
        "A: It’s urgent! If you don’t provide the account and your PIN, we can’t protect your money.",
        "B: I should call the official bank number before giving this information.",
        "A: There’s no time. Just give me your PIN now!",
        "B: I’ll hang up and check with the bank directly. Goodbye.",
        "A: Wait, no, don’t disconnect—!"
    ],
    [
        "A: Hi, I’m calling from the government tax office. We have an urgent notice for you.",
        "B: The tax office? Is there a problem?",
        "A: Yes, there is a warrant for your arrest due to unpaid taxes. You need to pay immediately.",
        "B: That sounds suspicious. I don’t think the tax office calls like this.",
        "A: If you don’t give me your credit card number, the police will be at your door soon.",
        "B: No, I’m going to hang up and verify through official channels.",
        "A: Don’t hang up! You must pay now!",
        "B: (Hangs up)"
    ]
]

all_conversations = legal_conversations + vishing_conversations

In [4]:
########################################
# 2. Build a Vocabulary
########################################

In [5]:
SPECIAL_TOKENS = ["<PAD>", "<BOS>", "<EOS>", "<UNK>"]
PAD_TOKEN, BOS_TOKEN, EOS_TOKEN, UNK_TOKEN = SPECIAL_TOKENS

In [6]:
class Vocabulary:
    def __init__(self):
        self.token2idx = {}
        self.idx2token = []
        self.pad_index = None
        self.bos_index = None
        self.eos_index = None
        self.unk_index = None
    
    def build_vocab(self, text_list, min_freq=1):
        """
        text_list: A list of strings from which to build a vocabulary.
        min_freq : Minimum frequency for a token to be included (optional).
        """
        # Step 1: Collect frequency
        freq = {}
        for text in text_list:
            for token in text.split():
                freq[token] = freq.get(token, 0) + 1
        
        # Step 2: Initialize with special tokens
        idx = 0
        for st in SPECIAL_TOKENS:
            self.token2idx[st] = idx
            self.idx2token.append(st)
            idx += 1
        
        # Step 3: Add tokens based on frequency
        for token, count in freq.items():
            if count >= min_freq and token not in self.token2idx:
                self.token2idx[token] = idx
                self.idx2token.append(token)
                idx += 1
        
        # Step 4: Store indices for quick access
        self.pad_index = self.token2idx[PAD_TOKEN]
        self.bos_index = self.token2idx[BOS_TOKEN]
        self.eos_index = self.token2idx[EOS_TOKEN]
        self.unk_index = self.token2idx[UNK_TOKEN]
    
    def tokenize(self, text):
        # Simple whitespace split; in real use-cases consider advanced tokenization
        return text.split()
    
    def numericalize(self, text):
        # Convert text to list of token indices
        tokens = self.tokenize(text)
        return [self.token2idx.get(t, self.unk_index) for t in tokens]
    
    def denumericalize(self, indices):
        # Convert list of token indices back to text
        return [self.idx2token[idx] for idx in indices]
    
    @property
    def vocab_size(self):
        return len(self.idx2token)

In [7]:
# Flatten all lines in the dataset to build vocabulary
all_lines = []
for conv in all_conversations:
    for line in conv:
        # We'll just keep raw text. 
        # For advanced usage, remove punctuation, lower-case, etc.
        line_clean = line.strip()
        all_lines.append(line_clean)

vocab = Vocabulary()
vocab.build_vocab(all_lines, min_freq=1)
vocab_size = vocab.vocab_size
print("Vocabulary size:", vocab_size)


Vocabulary size: 260


In [8]:
########################################
# 3. Preparing the Dataset
########################################


In [9]:
class ConversationDataset(Dataset):
    """
    For each conversation, we flatten the lines into one 
    sequence:  <BOS> line1 <EOS> line2 <EOS> line3 <EOS> ...
    The next-token prediction objective:
      - Input: [<BOS>, line1 tokens, <EOS>, line2 tokens, ...]
      - Target: same sequence shifted 1 to the right
    """
    def __init__(self, conversations, vocab, max_length=128):
        super().__init__()
        self.samples = []
        self.vocab = vocab
        self.max_length = max_length
        
        for conv in conversations:
            # Flatten conversation lines with <EOS> in between
            # e.g. "<BOS> A: Hello ... <EOS> B: Hi ... <EOS> ..."
            token_list = [vocab.bos_index]  # start with <BOS>
            
            for line in conv:
                line_tokens = vocab.numericalize(line)
                # Add line tokens + <EOS>
                token_list.extend(line_tokens)
                token_list.append(vocab.eos_index)
            
            # If conversation too long, truncate
            if len(token_list) > self.max_length:
                token_list = token_list[:self.max_length]
            
            self.samples.append(token_list)
    
    def __len__(self):
        return len(self.samples)
    
    def __getitem__(self, idx):
        return self.samples[idx]

In [10]:
def collate_fn(batch):
    """
    batch: list of token sequences of different lengths.
    We pad them to the max length in the batch.
    Then create input/target by shifting by 1 for next-token prediction.
    """
    # Find the longest sequence in the batch
    max_len = max(len(seq) for seq in batch)
    
    padded_inp = []
    padded_tgt = []
    
    for seq in batch:
        # Input is seq[:-1], target is seq[1:]
        # But we first pad to max_len
        inp_seq = seq[:-1]  # all but last
        tgt_seq = seq[1:]   # all but first
        
        # Pad input and target to max_len-1
        if len(inp_seq) < max_len-1:
            inp_seq += [vocab.pad_index] * (max_len - 1 - len(inp_seq))
        if len(tgt_seq) < max_len-1:
            tgt_seq += [vocab.pad_index] * (max_len - 1 - len(tgt_seq))
        
        padded_inp.append(inp_seq)
        padded_tgt.append(tgt_seq)
    
    # Convert to tensors
    inp_tensor = torch.tensor(padded_inp, dtype=torch.long)
    tgt_tensor = torch.tensor(padded_tgt, dtype=torch.long)
    return inp_tensor, tgt_tensor


In [11]:

dataset = ConversationDataset(all_conversations, vocab, max_length=128)
dataloader = DataLoader(dataset, batch_size=2, shuffle=True, collate_fn=collate_fn)

In [12]:
########################################
# 4. Decoder-Only Transformer (from scratch)
########################################

In [13]:
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        
        pe = pe.unsqueeze(0)  # shape: [1, max_len, d_model]
        self.register_buffer('pe', pe)
    
    def forward(self, x):
        """
        x: [batch_size, seq_len, d_model]
        We add positional encoding to x.
        """
        seq_len = x.size(1)
        # x + pe[:, :seq_len, :]
        return x + self.pe[:, :seq_len, :]

In [14]:
class MultiHeadSelfAttention(nn.Module):
    def __init__(self, d_model, n_heads):
        super().__init__()
        assert d_model % n_heads == 0, "d_model must be divisible by n_heads."
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_model // n_heads
        
        self.q_linear = nn.Linear(d_model, d_model)
        self.k_linear = nn.Linear(d_model, d_model)
        self.v_linear = nn.Linear(d_model, d_model)
        
        self.out = nn.Linear(d_model, d_model)
    
    def forward(self, x, mask=None):
        """
        x: [batch_size, seq_len, d_model]
        mask: [batch_size, seq_len, seq_len] or broadcastable shape
        """
        bsz, seq_len, _ = x.shape
        
        # 1) Linear projection
        q = self.q_linear(x)  # [bsz, seq_len, d_model]
        k = self.k_linear(x)
        v = self.v_linear(x)
        
        # 2) Split into heads
        q = q.view(bsz, seq_len, self.n_heads, self.d_k).transpose(1, 2)  # [bsz, n_heads, seq_len, d_k]
        k = k.view(bsz, seq_len, self.n_heads, self.d_k).transpose(1, 2)
        v = v.view(bsz, seq_len, self.n_heads, self.d_k).transpose(1, 2)
        
        # 3) Scaled dot-product attention
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)  # [bsz, n_heads, seq_len, seq_len]
        
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))
        
        attn = torch.softmax(scores, dim=-1)  # [bsz, n_heads, seq_len, seq_len]
        out = torch.matmul(attn, v)  # [bsz, n_heads, seq_len, d_k]
        
        # 4) Recombine heads
        out = out.transpose(1, 2).contiguous().view(bsz, seq_len, self.d_model)  # [bsz, seq_len, d_model]
        
        # 5) Final linear layer
        out = self.out(out)
        return out


In [15]:
class FeedForward(nn.Module):
    def __init__(self, d_model, d_ff, dropout=0.1):
        super().__init__()
        self.linear1 = nn.Linear(d_model, d_ff)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(d_ff, d_model)
        self.relu = nn.ReLU()
    
    def forward(self, x):
        x = self.linear1(x)
        x = self.relu(x)
        x = self.dropout(x)
        x = self.linear2(x)
        return x


In [16]:
class DecoderLayer(nn.Module):
    def __init__(self, d_model, n_heads, d_ff, dropout=0.1):
        super().__init__()
        self.self_attn = MultiHeadSelfAttention(d_model, n_heads)
        self.layernorm1 = nn.LayerNorm(d_model)
        
        self.feed_forward = FeedForward(d_model, d_ff, dropout)
        self.layernorm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, x, mask=None):
        # Self-attention
        attn_out = self.self_attn(x, mask=mask)
        x = x + self.dropout(attn_out)
        x = self.layernorm1(x)
        
        # Feed-forward
        ff_out = self.feed_forward(x)
        x = x + self.dropout(ff_out)
        x = self.layernorm2(x)
        
        return x

In [17]:
class DecoderOnlyTransformer(nn.Module):
    def __init__(self, 
                 vocab_size,
                 d_model=256,
                 n_heads=4,
                 num_layers=3,
                 d_ff=1024,
                 max_len=512,
                 dropout=0.1):
        super().__init__()
        self.vocab_size = vocab_size
        self.d_model = d_model
        
        self.token_emb = nn.Embedding(vocab_size, d_model)
        self.pos_emb = PositionalEncoding(d_model, max_len)
        
        self.layers = nn.ModuleList([
            DecoderLayer(d_model, n_heads, d_ff, dropout) for _ in range(num_layers)
        ])
        
        self.fc_out = nn.Linear(d_model, vocab_size)
    
    def generate_causal_mask(self, seq_len, device):
        """
        Generate an upper-triangular causal mask 
        so each token can only attend to tokens on its left.
        Shape: [seq_len, seq_len]
        """
        mask = torch.triu(torch.ones((seq_len, seq_len), device=device), diagonal=1)
        # 1 = no access (future), 0 = can attend
        mask = (mask == 0)  # invert
        return mask
    
    def forward(self, x):
        """
        x: [batch_size, seq_len]
        Returns: [batch_size, seq_len, vocab_size]
        """
        bsz, seq_len = x.shape
        device = x.device
        
        # Embedding
        x = self.token_emb(x)  # [bsz, seq_len, d_model]
        x = self.pos_emb(x)    # add positional encodings
        
        # Causal mask
        mask = self.generate_causal_mask(seq_len, device)  # [seq_len, seq_len]
        # We need a broadcastable shape [bsz, n_heads, seq_len, seq_len], 
        # but let's just keep it [seq_len, seq_len] if we handle it in attention.
        
        for layer in self.layers:
            x = layer(x, mask=mask)
        
        logits = self.fc_out(x)  # [bsz, seq_len, vocab_size]
        return logits

In [18]:
########################################
# 5. Training Loop
########################################

In [19]:
# Hyperparameters
d_model = 256
n_heads = 4
num_layers = 3
d_ff = 1024
dropout = 0.1
learning_rate = 1e-3
num_epochs = 5

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("Using device:", device)


Using device: cpu


In [20]:
model = DecoderOnlyTransformer(
    vocab_size=vocab_size,
    d_model=d_model,
    n_heads=n_heads,
    num_layers=num_layers,
    d_ff=d_ff,
    max_len=128,
    dropout=dropout
)
model.to(device)

DecoderOnlyTransformer(
  (token_emb): Embedding(260, 256)
  (pos_emb): PositionalEncoding()
  (layers): ModuleList(
    (0-2): 3 x DecoderLayer(
      (self_attn): MultiHeadSelfAttention(
        (q_linear): Linear(in_features=256, out_features=256, bias=True)
        (k_linear): Linear(in_features=256, out_features=256, bias=True)
        (v_linear): Linear(in_features=256, out_features=256, bias=True)
        (out): Linear(in_features=256, out_features=256, bias=True)
      )
      (layernorm1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
      (feed_forward): FeedForward(
        (linear1): Linear(in_features=256, out_features=1024, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
        (linear2): Linear(in_features=1024, out_features=256, bias=True)
        (relu): ReLU()
      )
      (layernorm2): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
  )
  (fc_out): Linear(in_features=256, out_features=2

In [21]:
criterion = nn.CrossEntropyLoss(ignore_index=vocab.pad_index)
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

In [22]:
model.train()
for epoch in range(num_epochs):
    total_loss = 0
    for inp, tgt in dataloader:
        inp, tgt = inp.to(device), tgt.to(device)
        
        optimizer.zero_grad()
        logits = model(inp)  # [batch_size, seq_len, vocab_size]
        
        # Flatten for loss calculation: (batch * seq_len, vocab_size)
        # And compare with target: (batch * seq_len)
        logits_2d = logits.view(-1, vocab_size)
        tgt_1d = tgt.view(-1)
        
        loss = criterion(logits_2d, tgt_1d)
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()
    
    avg_loss = total_loss / len(dataloader)
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {avg_loss:.4f}")

print("Training finished.\n")

Epoch 1/5, Loss: 5.5071
Epoch 2/5, Loss: 4.6805
Epoch 3/5, Loss: 4.0514
Epoch 4/5, Loss: 3.3875
Epoch 5/5, Loss: 2.8585
Training finished.



In [23]:
########################################
# 6. Generation (Autoregressive Inference)
########################################


In [24]:
def generate_text(prefix, model, vocab, max_new_tokens=20):
    """
    prefix: list of tokens (indices) that serve as the context.
    model: the trained model.
    Return the entire sequence (prefix + newly generated tokens).
    """
    model.eval()
    with torch.no_grad():
        x = torch.tensor(prefix, dtype=torch.long, device=device).unsqueeze(0)
        # shape: [1, prefix_len]
        
        for _ in range(max_new_tokens):
            # Forward pass
            logits = model(x)  # [1, current_len, vocab_size]
            
            # Get last token's logits
            last_token_logits = logits[:, -1, :]  # [1, vocab_size]
            
            # Greedy
            next_token = torch.argmax(last_token_logits, dim=-1)  # [1]
            # next_token = sample_from_probs(last_token_logits)  # for sampling
            
            if next_token.item() == vocab.eos_index:
                # Stop if we hit <EOS>
                break
            
            # Append to sequence
            x = torch.cat([x, next_token.unsqueeze(0)], dim=1)
        
        return x.squeeze(0).tolist()  # [full_sequence_length]

# Example usage: Provide a partial line as prefix
prefix_text = "A: Hello, I am John from your bank."
prefix_tokens = [vocab.bos_index] + vocab.numericalize(prefix_text)

generated_indices = generate_text(prefix_tokens, model, vocab, max_new_tokens=30)
generated_text = vocab.denumericalize(generated_indices)

print("Generated token indices:", generated_indices)
print("\nGenerated text (tokens):")
print(generated_text)

# Convert to string (roughly)
output_string = " ".join(generated_text)
print("\nFinal Output Text:")
print(output_string)

Generated token indices: [1, 4, 5, 15, 3, 3, 8, 43, 3, 14, 15, 14, 15, 14, 15, 14, 15, 14, 15, 14, 15, 14, 15, 14, 15, 14, 15, 14, 15, 14, 15, 14, 15, 14, 15, 14, 15, 14, 15]

Generated text (tokens):
['<BOS>', 'A:', 'Hello,', 'I', '<UNK>', '<UNK>', 'from', 'your', '<UNK>', 'can', 'I', 'can', 'I', 'can', 'I', 'can', 'I', 'can', 'I', 'can', 'I', 'can', 'I', 'can', 'I', 'can', 'I', 'can', 'I', 'can', 'I', 'can', 'I', 'can', 'I', 'can', 'I', 'can', 'I']

Final Output Text:
<BOS> A: Hello, I <UNK> <UNK> from your <UNK> can I can I can I can I can I can I can I can I can I can I can I can I can I can I can I


In [25]:
partial_conversation = [
    "A: Hello, I am John from your bank.",
    "B: Hi John, I was expecting your call."
]


In [26]:
# 1. Combine the partial lines into a single string
prefix_text = " ".join(partial_conversation)

# 2. Convert text to token indices
prefix_tokens = [vocab.bos_index] + vocab.numericalize(prefix_text)

# 3. Generate text
generated_indices = generate_text(prefix_tokens, model, vocab, max_new_tokens=50)
generated_text = vocab.denumericalize(generated_indices)

# 4. Print the result
print("Generated token indices:", generated_indices)
print("\nGenerated text (tokens):")
print(generated_text)

# 5. Optionally join tokens into a single string for readability
output_string = " ".join(generated_text)
print("\nFinal Output Text:")
print(output_string)


Generated token indices: [1, 4, 5, 15, 3, 3, 8, 43, 3, 19, 20, 3, 15, 22, 3, 43, 3, 17, 31]

Generated text (tokens):
['<BOS>', 'A:', 'Hello,', 'I', '<UNK>', '<UNK>', 'from', 'your', '<UNK>', 'B:', 'Hi', '<UNK>', 'I', 'was', '<UNK>', 'your', '<UNK>', 'you', 'assist']

Final Output Text:
<BOS> A: Hello, I <UNK> <UNK> from your <UNK> B: Hi <UNK> I was <UNK> your <UNK> you assist
