Encoder-Decoder Layers

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
import math
import csv
import ast
import matplotlib.pyplot as plt
import os
from torch.cuda.amp import autocast, GradScaler

# Ask PyTorch's CUDA caching allocator to use expandable segments; the
# variable is set here, before any CUDA tensor is created in this file.
os.environ.update({"PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:True"})

# Prefer the GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

# --------------------- Dataset ---------------------
class SequenceClassificationDataset(Dataset):
    """In-memory dataset of (input, output) integer token sequences from a CSV file.

    Every non-blank row must carry, in its first CSV field, two Python list
    literals written back to back, e.g. ``[1, 2, 3][4, 5, 6]``: the first list
    is the input sequence, the second the output sequence. All input
    sequences (and all output sequences) must share one length because they
    are stacked into a single ``torch.long`` tensor.

    Attributes:
        inputs: ``torch.long`` tensor holding every input sequence.
        outputs: ``torch.long`` tensor holding every output sequence.
        vocab_size: Plain ``int`` equal to the largest token id seen in
            either tensor, plus one.

    Raises:
        ValueError: If a row is not of the ``[...][...]`` form or the file
            contains no sequences at all.
    """

    def __init__(self, csv_file):
        inputs = []
        outputs = []

        # newline='' is what the csv module asks for so that it can handle
        # line endings inside quoted fields itself.
        with open(csv_file, 'r', newline='') as f:
            for row_no, row in enumerate(csv.reader(f), start=1):
                # Blank lines (typically a trailing newline) carry no sample.
                if not row or not row[0].strip():
                    continue

                parts = row[0].split('][')
                if len(parts) != 2:
                    raise ValueError(
                        f"{csv_file}: row {row_no}: expected '[...][...]', got {row[0]!r}"
                    )
                input_str, output_str = parts
                # The split consumed the inner brackets; restore them before
                # parsing. literal_eval only accepts literals, never code.
                inputs.append(ast.literal_eval(input_str + ']'))
                outputs.append(ast.literal_eval('[' + output_str))

        if not inputs:
            raise ValueError(f"{csv_file}: no sequences found")

        self.inputs = torch.tensor(inputs, dtype=torch.long)
        self.outputs = torch.tensor(outputs, dtype=torch.long)
        # Store a real int (not a 0-dim tensor) so the value can be handed to
        # layer constructors, printed, or serialised without surprises.
        self.vocab_size = int(max(self.inputs.max().item(), self.outputs.max().item())) + 1

    def __len__(self):
        return len(self.inputs)

    def __getitem__(self, idx):
        return self.inputs[idx], self.outputs[idx]

# ------------------ Token Weights ------------------
def compute_token_weights(dataset, vocab_size):
    """Return one loss weight per token id, inversely related to its frequency.

    Token ids are counted over every position of ``dataset.outputs``. Each
    weight is ``1 / sqrt(count + 1e-6)``; the vector is then rescaled so its
    entries sum to ``vocab_size`` and moved to the module-level ``device``.

    Args:
        dataset: Object exposing an integer tensor attribute ``outputs``.
        vocab_size: Number of token ids; also the length of the result.

    Returns:
        Float tensor of length ``vocab_size`` located on ``device``.
    """
    flat_targets = dataset.outputs.view(-1)
    counts = torch.bincount(flat_targets, minlength=vocab_size).float()
    # The small epsilon keeps never-seen ids from dividing by zero.
    inv_sqrt = 1.0 / torch.sqrt(counts + 1e-6)
    rescaled = inv_sqrt / inv_sqrt.sum() * vocab_size
    return rescaled.to(device)

# --------------------- Positional Encoding ---------------------
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position information to a sequence-first tensor.

    A table of shape ``(max_len, 1, d_model)`` is precomputed and registered
    as the buffer ``pe``: even feature columns hold sines, odd columns hold
    cosines. Both even and odd ``d_model`` values are supported.

    Args:
        d_model: Size of the feature (last) dimension.
        dropout: Dropout probability applied after the encoding is added.
        max_len: Number of positions for which encodings are precomputed.
    """

    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        angles = position * div_term  # (max_len, ceil(d_model / 2))

        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(angles)
        # For odd d_model there is one fewer odd column than even columns, so
        # the last frequency is dropped for the cosine half.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])

        # Middle axis of size 1 broadcasts over the batch dimension.
        self.register_buffer('pe', pe.unsqueeze(1))

    def forward(self, x):
        """Return ``dropout(x + pe[:seq_len])``; ``x.size(0)`` is used as the sequence length."""
        x = x + self.pe[:x.size(0), :]
        return self.dropout(x)

# --------------------- Seq2Seq Transformer ---------------------
class TransformerSeq2Seq(nn.Module):
    """Encoder-decoder Transformer mapping a source token sequence to target-token logits.

    Source and target share a single embedding table. The encoder and decoder
    stacks run sequence-first (``batch_first=False``), so ``forward``
    transposes on the way in and on the way out.

    Args:
        vocab_size: Number of token ids (embedding rows and output logits).
        d_model: Embedding / model width.
        nhead: Number of attention heads.
        num_encoder_layers: Number of stacked encoder layers.
        num_decoder_layers: Number of stacked decoder layers.
        dim_feedforward: Hidden width of each feed-forward sub-layer.
        dropout: Dropout probability used by the positional encodings and layers.
    """

    def __init__(self, vocab_size, d_model=256, nhead=4, num_encoder_layers=4, num_decoder_layers=4, dim_feedforward=1024, dropout=0.4):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoder = PositionalEncoding(d_model, dropout)
        self.pos_decoder = PositionalEncoding(d_model, dropout)

        encoder_layer = nn.TransformerEncoderLayer(d_model, nhead, dim_feedforward, dropout, batch_first=False)
        decoder_layer = nn.TransformerDecoderLayer(d_model, nhead, dim_feedforward, dropout, batch_first=False)

        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_encoder_layers)
        self.decoder = nn.TransformerDecoder(decoder_layer, num_layers=num_decoder_layers)
        self.output_layer = nn.Linear(d_model, vocab_size)

    def forward(self, src, tgt):
        """Return logits of shape ``(batch, tgt_len, vocab_size)``.

        Args:
            src: ``(batch, src_len)`` tensor of source token ids.
            tgt: ``(batch, tgt_len)`` tensor of decoder input token ids.

        Decoder self-attention is causally masked: position ``t`` may attend
        only to positions ``<= t``. Without the mask, teacher-forced training
        lets each position read the very token it is asked to predict.
        """
        # (batch, seq) -> (seq, batch, d_model) for the sequence-first layers.
        src = self.embedding(src).transpose(0, 1)
        tgt = self.embedding(tgt).transpose(0, 1)

        src = self.pos_encoder(src)
        tgt = self.pos_decoder(tgt)

        # Boolean mask, True above the diagonal = "may not attend".
        tgt_len = tgt.size(0)
        causal_mask = torch.triu(
            torch.ones(tgt_len, tgt_len, dtype=torch.bool, device=tgt.device),
            diagonal=1,
        )

        memory = self.encoder(src)
        output = self.decoder(tgt, memory, tgt_mask=causal_mask)
        return self.output_layer(output).transpose(0, 1)

# --------------------- Training ---------------------
def _run_epoch(model, loader, criterion, pad_idx, use_amp, optimizer=None, scaler=None):
    """Run one pass over ``loader``; train when ``optimizer`` is given, else evaluate.

    Returns ``(mean batch loss, token accuracy, sequence accuracy)``. Token
    accuracy ignores positions whose target equals ``pad_idx``; a sequence
    counts as exact when every non-pad position is predicted correctly.
    """
    training = optimizer is not None
    model.train(training)
    loss_sum, n_tokens, n_correct, n_exact = 0.0, 0, 0, 0

    with torch.set_grad_enabled(training):
        for inputs, targets in loader:
            inputs, targets = inputs.to(device), targets.to(device)
            # Teacher forcing: feed tokens [0, T-1), predict tokens [1, T).
            tgt_input = targets[:, :-1]
            tgt_output = targets[:, 1:]

            if training:
                optimizer.zero_grad()
            with autocast(enabled=use_amp):
                logits = model(inputs, tgt_input)
                loss = criterion(logits.reshape(-1, logits.size(-1)), tgt_output.reshape(-1))

            if training:
                scaler.scale(loss).backward()
                # Unscale first so clipping sees the true gradient norm.
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                scaler.step(optimizer)
                scaler.update()

            loss_sum += loss.item()
            preds = logits.argmax(dim=-1)
            mask = tgt_output != pad_idx
            hits = preds == tgt_output
            n_correct += (hits & mask).sum().item()
            n_tokens += mask.sum().item()
            # Pad positions do not need to be predicted correctly.
            n_exact += (hits | ~mask).all(dim=1).sum().item()

    # Guard the denominators so an empty loader reports zeros instead of crashing.
    mean_loss = loss_sum / max(len(loader), 1)
    token_acc = n_correct / n_tokens if n_tokens else 0.0
    seq_acc = n_exact / max(len(loader.dataset), 1)
    return mean_loss, token_acc, seq_acc


def train_model(model, train_loader, val_loader, weights, epochs=50, patience=5,
                save_path='E:/Language Model Project/final-model/best_model.pth',
                pad_idx=2, lr_patience=None):
    """Train ``model`` with teacher forcing, early stopping and LR decay on plateau.

    After every epoch the model is evaluated on ``val_loader``. The weights
    with the lowest validation loss so far are written to ``save_path``.
    Training stops once the validation loss has failed to improve for
    ``patience`` consecutive epochs. Finally the loss curves are plotted,
    saved to ``loss_curve.png`` and shown.

    Args:
        model: Module called as ``model(inputs, tgt_input)`` returning logits
            whose last dimension is the class dimension.
        train_loader: Loader yielding ``(inputs, targets)`` batches for training.
        val_loader: Loader yielding ``(inputs, targets)`` batches for validation.
        weights: Per-class weight tensor passed to ``nn.CrossEntropyLoss``.
        epochs: Maximum number of epochs.
        patience: Epochs without validation improvement before stopping.
        save_path: Destination of the best checkpoint; its parent directory
            is created when missing.
        pad_idx: Target id excluded from the loss and from the accuracies.
        lr_patience: Patience handed to ``ReduceLROnPlateau``. ``None`` picks
            ``max(0, (patience - 1) // 2)`` so the learning rate can be
            halved before early stopping ends the run; with both patiences
            equal the scheduler needs one more bad epoch than early stopping
            and would never fire on a plain plateau.
    """
    model.to(device)
    criterion = nn.CrossEntropyLoss(weight=weights, ignore_index=pad_idx)
    optimizer = optim.AdamW(model.parameters(), lr=1e-4)
    if lr_patience is None:
        lr_patience = max(0, (patience - 1) // 2)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=lr_patience, factor=0.5)

    # Mixed precision only helps (and only works without warnings) on CUDA.
    use_amp = device.type == "cuda"
    scaler = GradScaler(enabled=use_amp)

    save_dir = os.path.dirname(save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    best_val_loss = float('inf')
    epochs_no_improve = 0
    train_losses, val_losses = [], []

    for epoch in range(epochs):
        avg_train_loss, token_acc, seq_acc = _run_epoch(
            model, train_loader, criterion, pad_idx, use_amp,
            optimizer=optimizer, scaler=scaler,
        )
        val_loss, val_token_acc, val_seq_acc = _run_epoch(
            model, val_loader, criterion, pad_idx, use_amp,
        )
        scheduler.step(val_loss)

        train_losses.append(avg_train_loss)
        val_losses.append(val_loss)
        lr = optimizer.param_groups[0]['lr']

        improved = val_loss < best_val_loss
        if improved:
            best_val_loss = val_loss
            epochs_no_improve = 0
            torch.save(model.state_dict(), save_path)
        else:
            epochs_no_improve += 1

        # Report before the early-stop check so the final epoch is logged too.
        print(f"Epoch [{epoch+1}/{epochs}] | Train Loss: {avg_train_loss:.4f} | Val Loss: {val_loss:.4f} | LR: {lr:.5f}")
        print(f"    TokenAcc: train={token_acc:.4f}, val={val_token_acc:.4f} | SeqAcc: train={seq_acc:.4f}, val={val_seq_acc:.4f}")

        if not improved and epochs_no_improve >= patience:
            print(f"Early stopping at epoch {epoch+1}")
            break

    plt.figure(figsize=(10, 6))
    plt.plot(train_losses, label='Training Loss')
    plt.plot(val_losses, label='Validation Loss', linestyle='--')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('Loss Curve')
    plt.legend()
    plt.grid(True)
    plt.savefig('loss_curve.png')
    plt.show()

# --------------------- Main ---------------------
if __name__ == "__main__":
    # Author's note on the path below mentions an alternative suffix: _20_shuffle
    data_path = 'E:/Language Model Project/final-model/Data/train_15-20_10_shuffle.csv'
    dataset = SequenceClassificationDataset(data_path)
    weights = compute_token_weights(dataset, dataset.vocab_size)

    # 80 % of the samples go to training, the remainder to validation.
    n_samples = len(dataset)
    train_size = int(0.8 * n_samples)
    val_size = n_samples - train_size
    train_dataset, val_dataset = random_split(dataset, [train_size, val_size])
    print("dataset had been loaded")

    # Both loaders share batch size and worker count; only training shuffles.
    loader_kwargs = dict(batch_size=128, num_workers=0)
    train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
    val_loader = DataLoader(val_dataset, **loader_kwargs)

    hparams = {
        "d_model": 256,
        "nhead": 4,
        "num_encoder_layers": 4,
        "num_decoder_layers": 4,
        "dim_feedforward": 1024,
        "dropout": 0.3,
    }
    model = TransformerSeq2Seq(vocab_size=dataset.vocab_size, **hparams)

    train_model(model, train_loader, val_loader, weights, epochs=50, patience=5)

KeyboardInterrupt: 