In [1]:
import pickle
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
from tqdm import tqdm
import math
import numpy as np

from accelerate import Accelerator, notebook_launcher

In [2]:
# Vocabulary: every (i, j) pair with i, j in 0..7 gets id i * 8 + j + 1 (ids 1..64),
# followed by the four direction tokens (ids 65..68). Id 0 is not assigned here.
token_to_idx = {(row, col): row * 8 + col + 1 for row in range(8) for col in range(8)}
token_to_idx.update({"up": 65, "down": 66, "left": 67, "right": 68})

class SequenceDataset(Dataset):
    """Map-style dataset over (input, target) token-sequence pairs.

    Each element of ``data`` is unpacked as ``(X_sequence, Y_sequence)``. Every
    token of both sequences is looked up in ``token_to_idx`` and the resulting
    id lists are returned as two 1-D ``torch.long`` tensors.
    """

    def __init__(self, data, token_to_idx):
        self.token_to_idx = token_to_idx
        self.data = data

    def __len__(self):
        return len(self.data)

    def _encode(self, tokens):
        """Turn an iterable of tokens into a ``torch.long`` tensor of vocabulary ids."""
        lookup = self.token_to_idx
        return torch.tensor([lookup[tok] for tok in tokens], dtype=torch.long)

    def __getitem__(self, idx):
        source_tokens, target_tokens = self.data[idx]
        return self._encode(source_tokens), self._encode(target_tokens)

def collate_fn(batch):
    """Collate ``(X, Y)`` tensor pairs into two padded, batch-first tensors.

    Inputs and targets are padded independently with the value 0, each up to the
    longest sequence on its own side of the batch.

    Returns:
        tuple: ``(Xs_padded, Ys_padded)``.
    """
    inputs, targets = zip(*batch)
    return tuple(
        pad_sequence(seqs, batch_first=True, padding_value=0)
        for seqs in (inputs, targets)
    )

In [3]:
class Config:
    """Hyper-parameter container read by the GPT model classes.

    Args:
        vocab_size: number of rows in the token embedding / size of the output layer.
        block_size: maximum sequence length (size of the positional table and causal mask).
        n_embd: embedding width.
        n_head: number of attention heads.
        n_layer: number of transformer blocks.

    The three dropout probabilities are not constructor arguments; they are all
    fixed to 0.1.
    """
    def __init__(self, vocab_size, block_size, n_embd, n_head, n_layer):
        # Dropout rates: embedding, residual path, attention.
        self.embd_pdrop = self.resid_pdrop = self.attn_pdrop = 0.1

        # Architecture sizes supplied by the caller.
        self.vocab_size, self.block_size = vocab_size, block_size
        self.n_embd, self.n_head, self.n_layer = n_embd, n_head, n_layer

class CausalSelfAttention(nn.Module):
    """Multi-head self-attention with a lower-triangular (causal) mask.

    Args:
        config: object exposing ``n_embd``, ``n_head``, ``block_size``,
            ``attn_pdrop`` and ``resid_pdrop``. ``n_embd`` must be divisible
            by ``n_head``.

    ``forward`` takes ``x`` of shape ``(B, T, C)`` with ``T <= block_size`` and
    returns a tensor of the same shape.
    """
    def __init__(self, config):
        super().__init__()
        self.config = config
        assert self.config.n_embd % self.config.n_head == 0, "Embedding dimension must be divisible by number of heads."

        # Key, query, and value projections
        self.key = nn.Linear(self.config.n_embd, self.config.n_embd)
        self.query = nn.Linear(self.config.n_embd, self.config.n_embd)
        self.value = nn.Linear(self.config.n_embd, self.config.n_embd)

        # Dropout layers: one on the attention weights, one on the projected output
        self.attn_drop = nn.Dropout(self.config.attn_pdrop)
        self.resid_drop = nn.Dropout(self.config.resid_pdrop)

        # Output projection
        self.proj = nn.Linear(self.config.n_embd, self.config.n_embd)

        # Causal mask to prevent attention to future tokens, shaped
        # (1, 1, block_size, block_size) so it broadcasts over batch and heads.
        # Kept as a regular (persistent) buffer so existing checkpoints still load.
        self.register_buffer("mask", torch.tril(torch.ones((config.block_size, config.block_size))).unsqueeze(0).unsqueeze(1))

    def forward(self, x):
        B, T, C = x.size()
        n_head = self.config.n_head
        head_dim = C // n_head

        # Project and split into heads: (B, T, C) -> (B, n_head, T, head_dim)
        k = self.key(x).view(B, T, n_head, head_dim).transpose(1, 2)
        q = self.query(x).view(B, T, n_head, head_dim).transpose(1, 2)
        v = self.value(x).view(B, T, n_head, head_dim).transpose(1, 2)

        # Crop the precomputed mask to the current sequence length
        mask = self.mask[:, :, :T, :T]

        # Scaled dot product attention with causality
        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(head_dim))
        att = att.masked_fill(mask == 0, float('-inf'))
        att = F.softmax(att, dim=-1)
        # Attention dropout: the layer was constructed above but previously never
        # applied, so attn_pdrop had no effect. Identity in eval mode.
        att = self.attn_drop(att)

        # Weighted sum of values, then merge heads back to (B, T, C)
        y = (att @ v).transpose(1, 2).reshape(B, T, C)

        # Apply projection and residual dropout
        y = self.resid_drop(self.proj(y))
        return y


class TransformerBlock(nn.Module):
    """One transformer layer: LayerNorm -> attention and LayerNorm -> MLP,
    each wrapped in a residual connection.

    The MLP expands to ``4 * n_embd``, applies GELU, projects back to
    ``n_embd`` and ends with dropout at rate ``config.resid_pdrop``.
    """
    def __init__(self, config):
        super().__init__()
        width = config.n_embd
        hidden = 4 * width

        # Sub-modules are registered in the same order as before so parameter
        # ordering and state_dict keys are unchanged.
        self.attn = CausalSelfAttention(config)
        self.ln1 = nn.LayerNorm(width)
        self.mlp = nn.Sequential(
            nn.Linear(width, hidden),
            nn.GELU(),
            nn.Linear(hidden, width),
            nn.Dropout(config.resid_pdrop),
        )
        self.ln2 = nn.LayerNorm(width)

    def forward(self, x):
        # Attention sub-layer (normalise first, then add back to the input).
        attn_out = self.attn(self.ln1(x))
        x = x + attn_out
        # Feed-forward sub-layer, same pattern.
        mlp_out = self.mlp(self.ln2(x))
        return x + mlp_out

class TransformerModel(nn.Module):
    """Decoder-only transformer: token + learned positional embeddings, a stack
    of ``TransformerBlock`` layers, a final LayerNorm and a bias-free linear head
    that produces one logit per vocabulary entry.

    ``forward(idx)`` takes integer ids of shape ``(B, T)`` with
    ``T <= config.block_size`` and returns logits of shape ``(B, T, vocab_size)``.
    """
    def __init__(self, config):
        super().__init__()
        self.config = config
        cfg = config

        # Input embedding stem (positional table starts at zero and is learned)
        self.tok_emb = nn.Embedding(cfg.vocab_size, cfg.n_embd)
        self.pos_emb = nn.Parameter(torch.zeros(1, cfg.block_size, cfg.n_embd))
        self.drop = nn.Dropout(cfg.embd_pdrop)

        # Transformer blocks
        layers = [TransformerBlock(cfg) for _ in range(cfg.n_layer)]
        self.blocks = nn.Sequential(*layers)

        # Decoder head
        self.ln_f = nn.LayerNorm(cfg.n_embd)
        self.head = nn.Linear(cfg.n_embd, cfg.vocab_size, bias=False)

        # Re-initialise every sub-module with the scheme in _init_weights
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Initialise one sub-module: N(0, 0.02) weights and zero biases for
        Linear/Embedding layers, unit weight and zero bias for LayerNorm."""
        if isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
            return
        if not isinstance(module, (nn.Linear, nn.Embedding)):
            return
        module.weight.data.normal_(mean=0.0, std=0.02)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()

    def forward(self, idx):
        _, seq_len = idx.size()
        assert seq_len <= self.config.block_size, "Input Sequence Too Long."

        # Sum token and (cropped) positional embeddings, then run the stack
        hidden = self.tok_emb(idx) + self.pos_emb[:, :seq_len, :]
        hidden = self.blocks(self.drop(hidden))
        return self.head(self.ln_f(hidden))

In [4]:
# Sizes used to build Config in the cells below.
block_size = 201   # maximum sequence length the model accepts
vocab_size = 70    # token ids in use are 0..68, so this leaves one spare id

# NOTE(review): the four values below are not referenced by the visible cells,
# which pass literals to Config and rely on Config's fixed dropout rates.
# Confirm whether they are still needed.
embed_size = 512
num_heads = 8
num_layers = 8
dropout = 0.1

# collate_fn pads with 0, so register that id in the vocabulary as the padding token.
token_to_idx.update({'<pad>': 0})

In [5]:
# Load the preprocessed data from disk. Presumably a sequence of
# (X_sequence, Y_sequence) pairs, since slices of it are later fed to
# SequenceDataset — TODO confirm.
# NOTE(review): pickle.load can execute arbitrary code while unpickling; only
# use this with a file from a trusted source.
path = 'filtered_sequence.pkl'
with open(path, 'rb') as file:
    processed = pickle.load(file)

In [6]:
# Sequential (unshuffled) split of `processed`: first 80% train, next 10%
# validation, remainder test.
d = len(processed)
train_ratio = 0.8
valid_ratio = 0.1

# Slice boundaries, computed once and reused so adjacent splits share an edge.
train_end = int(train_ratio * d)
valid_end = int((train_ratio + valid_ratio) * d)

train = processed[:train_end]
validation = processed[train_end:valid_end]

test_exact = processed[valid_end:]
# test_valid = valid[int((train_ratio + valid_ratio) * d): ]

In [7]:
# Wrap each split in a SequenceDataset; all three share the same vocabulary.
train_dataset, valid_dataset, test_exact_dataset = (
    SequenceDataset(split, token_to_idx)
    for split in (train, validation, test_exact)
)
# test_valid_dataset = ValidDataset(test_valid, token_to_idx)

In [8]:
def validate_model(model, valid_loader, criterion):
    """Return the average per-token loss of ``model`` over ``valid_loader``.

    The model is switched to eval mode and run without gradient tracking.
    Positions whose target id is 0 (padding) are excluded from both the loss
    sum and the token count.

    Args:
        model: callable mapping an input batch to logits of shape
            ``(batch, seq, vocab)``.
        valid_loader: iterable of ``(X_batch, Y_batch)`` pairs.
        criterion: loss returning one value per target position
            (i.e. built with ``reduction='none'``).

    Returns:
        Sum of the unpadded losses divided by the number of unpadded positions,
        or ``0`` when there are no such positions.
    """
    pad_id = 0  # padding token index assumed throughout the notebook

    model.eval()
    loss_total = 0.0
    token_total = 0

    with torch.no_grad():
        for inputs, targets in valid_loader:
            scores = model(inputs)
            flat_scores = scores.view(-1, scores.size(-1))  # [batch * seq, vocab]
            flat_targets = targets.view(-1)                 # [batch * seq]

            keep = (flat_targets != pad_id).float()         # 1.0 at real tokens, 0.0 at padding
            per_token = criterion(flat_scores, flat_targets)

            loss_total += (per_token * keep).sum().item()
            token_total += keep.sum().item()

    if token_total > 0:
        return loss_total / token_total
    return 0

In [9]:
def training_function():
    """Train a ``TransformerModel`` on ``train_dataset`` with Hugging Face Accelerate.

    Runs 15 epochs of AdamW with a cosine-annealed learning rate. After each
    epoch the model is evaluated with ``validate_model`` on ``valid_dataset`` and
    a checkpoint ``Model_<epoch>.pth`` is written. When training finishes, the
    list of ``(train_loss, valid_loss)`` pairs is pickled to ``Loss History.pkl``.

    Reads the module-level names ``train_dataset``, ``valid_dataset``,
    ``collate_fn``, ``vocab_size`` and ``block_size``. Takes no arguments so it
    can be handed directly to ``notebook_launcher``.

    NOTE(review): with more than one process each rank evaluates only its own
    shard of the validation loader; the reported loss is not gathered across
    ranks — confirm whether that matters for your setup.
    """
    epochs = 15
    padding_token_index = 0  # id produced by collate_fn for padded positions

    criterion = nn.CrossEntropyLoss(ignore_index=padding_token_index, reduction='none')
    accelerator = Accelerator()

    train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, collate_fn=collate_fn)
    # Validation only accumulates a sum over all tokens, so shuffling adds nothing.
    valid_loader = DataLoader(valid_dataset, batch_size=32, shuffle=False, collate_fn=collate_fn)

    config = Config(vocab_size, block_size, n_layer=8, n_head=8, n_embd=512)
    model = TransformerModel(config)

    optimizer = torch.optim.AdamW(model.parameters(), lr=0.0001, weight_decay=0.01)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)

    train_loader, valid_loader, model, scheduler, optimizer = accelerator.prepare(
        train_loader, valid_loader, model, scheduler, optimizer
    )

    epoch_loss = []

    for epoch in range(epochs):
        model.train()

        total_loss = 0.0
        total_data = 0

        # Only the local main process draws the bar, to avoid duplicated output.
        progress_bar = tqdm(
            train_loader,
            total=len(train_loader),
            desc=f"Epoch {epoch}",
            disable=not accelerator.is_local_main_process,
        )

        for X_batch, Y_batch in progress_bar:
            optimizer.zero_grad()
            logits = model(X_batch)

            logits = logits.view(-1, logits.size(-1))  # Shape: [batch_size * seq_length, vocab_size]
            Y_batch = Y_batch.view(-1)  # Shape: [batch_size * seq_length]

            mask = (Y_batch != padding_token_index).float()

            loss_sum = (criterion(logits, Y_batch) * mask).sum()
            valid_positions = mask.sum()

            # clamp avoids 0/0 = NaN on an all-padding batch while still running
            # backward on every rank (skipping the step could desynchronise ranks).
            loss = loss_sum / valid_positions.clamp(min=1)
            accelerator.backward(loss)
            optimizer.step()

            total_loss += loss_sum.item()
            total_data += valid_positions.item()

            progress_bar.set_description(f"Epoch {epoch}, Avg Loss: {total_loss / max(total_data, 1):.4f}")

        progress_bar.close()
        train_loss = total_loss / max(total_data, 1)

        valid_loss = validate_model(model, valid_loader, criterion)
        accelerator.print(f"Validation Loss: {valid_loss}")

        scheduler.step()

        epoch_loss.append((train_loss, valid_loss))

        # Save the underlying model, not the distributed wrapper, so checkpoint
        # keys carry no wrapper prefix regardless of how training was launched.
        accelerator.wait_for_everyone()
        if accelerator.is_main_process:
            model_save_path = f"Model_{epoch+1}.pth"
            accelerator.save(accelerator.unwrap_model(model).state_dict(), model_save_path)

    # A single writer avoids several processes clobbering the same file.
    if accelerator.is_main_process:
        with open("Loss History.pkl", "wb") as f:
            pickle.dump(epoch_loss, f)


In [None]:
# Launch training_function from inside the notebook via Accelerate, using a single process.
notebook_launcher(training_function, num_processes = 1)

In [12]:
# Use the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print("Device: ", device)

# Rebuild the architecture with the same sizes used in training_function.
config = Config(vocab_size, block_size, n_layer=8, n_head=8, n_embd=512)
model = TransformerModel(config)

# NOTE(review): torch.load unpickles the file; only load checkpoints you trust
# (or pass weights_only=True if the installed PyTorch version supports it).
state_dict = torch.load('Model_4.pth', map_location=device)
# Checkpoints saved from a wrapped (DataParallel / DDP) model prefix every key
# with 'module.'. Strip it only as a prefix: str.replace would also delete the
# substring from the middle of a key.
state_dict = {k.removeprefix('module.'): v for k, v in state_dict.items()}

model.load_state_dict(state_dict)
model.to(device)

Device:  cuda


TransformerModel(
  (tok_emb): Embedding(70, 512)
  (drop): Dropout(p=0.1, inplace=False)
  (blocks): Sequential(
    (0): TransformerBlock(
      (attn): CausalSelfAttention(
        (key): Linear(in_features=512, out_features=512, bias=True)
        (query): Linear(in_features=512, out_features=512, bias=True)
        (value): Linear(in_features=512, out_features=512, bias=True)
        (attn_drop): Dropout(p=0.1, inplace=False)
        (resid_drop): Dropout(p=0.1, inplace=False)
        (proj): Linear(in_features=512, out_features=512, bias=True)
      )
      (ln1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
      (mlp): Sequential(
        (0): Linear(in_features=512, out_features=2048, bias=True)
        (1): GELU(approximate='none')
        (2): Linear(in_features=2048, out_features=512, bias=True)
        (3): Dropout(p=0.1, inplace=False)
      )
      (ln2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
    )
    (1): TransformerBlock(
      (attn): CausalS

In [None]:
# Token-level accuracy on the test split, counting only positions whose target
# is not the padding token.
correct_predictions = 0
total_predictions = 0

padding_index = token_to_idx['<pad>']
# Accuracy is a plain count over all tokens, so the order of batches is irrelevant.
test_exact_loader = DataLoader(test_exact_dataset, batch_size=32, shuffle=False, collate_fn=collate_fn)

model.eval()
with torch.no_grad():
    # Wrap data_loader with tqdm for a progress bar
    for X, Y_true in tqdm(test_exact_loader, desc="Evaluating", unit="batch"):
        X, Y_true = X.to(device), Y_true.to(device)
        logits = model(X)

        # Greedy prediction at every target position in one shot. Softmax is
        # monotonic, so argmax over the raw logits selects the same token.
        # The slice keeps only as many positions as the padded targets have.
        Y_pred = logits[:, :Y_true.shape[1], :].argmax(dim=-1)

        # Ignore padded target positions in both the numerator and denominator.
        non_padded_positions = Y_true != padding_index

        correct_predictions += ((Y_pred == Y_true) & non_padded_positions).sum().item()
        total_predictions += non_padded_positions.sum().item()

accuracy = correct_predictions / total_predictions if total_predictions > 0 else 0
print(f"Accuracy: {accuracy}")
