<a href="https://colab.research.google.com/github/Eupham/-/blob/master/untitled24.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:

!pip install datasets
import os
import time
import pickle
import torch
import torch.nn as nn
from torch.nn import TransformerEncoder, TransformerEncoderLayer, TransformerDecoder, TransformerDecoderLayer
import torch.optim as optim
from tqdm import tqdm
from torch.utils.data import DataLoader, Dataset
from collections import namedtuple
from sklearn.preprocessing import LabelEncoder
from torch.nn.utils.rnn import pad_sequence
from torchtext.data import get_tokenizer
from datasets import load_dataset

# Run on the GPU whenever PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Locations of the model checkpoint and the pickled vocabulary artifacts.
model_path, vocab_path, word_to_index_path = (
    '/content/checkpoint_epoch0_batch1000.pth',
    '/vocab.pkl',
    '/word_to_index.pkl',
)

# 1. Embedding Layer
class EmbeddingLayer(nn.Module):
    """Map integer token ids to learned ``d_model``-dimensional vectors.

    Args:
        vocab_size: number of rows in the embedding table.
        d_model: width of each embedding vector.
    """

    def __init__(self, vocab_size, d_model):
        super().__init__()
        # Attribute name is part of the state_dict key ("embedding.weight").
        self.embedding = nn.Embedding(num_embeddings=vocab_size, embedding_dim=d_model)

    def forward(self, x):
        """Look up the embedding of every id in ``x`` (output gains a trailing d_model dim)."""
        table = self.embedding
        return table(x)

# 2. Transformer Encoder
class TransformerEncoderModule(nn.Module):
    """Stack of ``num_layers`` Transformer encoder layers.

    The rest of this file feeds it embeddings of batch-first padded id tensors,
    i.e. shape (batch, seq, d_model). ``nn.TransformerEncoderLayer`` defaults to
    ``batch_first=False`` (seq, batch, d_model), which would make self-attention
    mix different samples of the batch instead of tokens of one sequence, so the
    layer is built batch-first by default.

    Args:
        d_model: embedding width (must be divisible by ``nhead``).
        nhead: number of attention heads.
        num_layers: number of stacked encoder layers.
        dim_feedforward: hidden width of each layer's feed-forward block.
        batch_first: layout of the input; pass False only for (seq, batch, d) input.
    """

    def __init__(self, d_model, nhead, num_layers, dim_feedforward, batch_first=True):
        super(TransformerEncoderModule, self).__init__()
        encoder_layers = TransformerEncoderLayer(
            d_model, nhead, dim_feedforward, batch_first=batch_first
        )
        self.transformer_encoder = TransformerEncoder(encoder_layers, num_layers)

    def forward(self, x, src_key_padding_mask=None):
        """Encode ``x``; positions where ``src_key_padding_mask`` is True are not attended to."""
        return self.transformer_encoder(x, src_key_padding_mask=src_key_padding_mask)

# 3. Classification Head
class ClassificationHead(nn.Module):
    """Per-position classifier applied to the encoder output.

    A bidirectional GRU (``d_model // 2`` hidden units per direction) is
    followed by a LeakyReLU and a linear projection to ``num_classes`` scores.

    ``forward`` returns raw logits. ``nn.CrossEntropyLoss`` applies
    log-softmax itself and expects unnormalized scores. The previous version
    also applied ``Softmax(dim=1)``, which for (batch, seq, classes) input
    normalized over the sequence axis rather than the classes. Use
    ``self.softmax`` (class axis) on the logits if probabilities are needed.

    Args:
        d_model: width of the incoming features.
        num_classes: number of output classes.
    """

    def __init__(self, d_model, num_classes):
        super(ClassificationHead, self).__init__()
        hidden = d_model // 2
        self.gru = nn.GRU(d_model, hidden, batch_first=True, bidirectional=True)
        self.leaky_relu = nn.LeakyReLU()
        # The bidirectional GRU emits 2 * hidden features; this equals d_model
        # for even d_model and avoids a shape mismatch when d_model is odd.
        self.fc = nn.Linear(2 * hidden, num_classes)
        # Not used in forward; kept for callers that want probabilities.
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        """Return class logits of shape (batch, seq, num_classes) for batch-first ``x``."""
        x, _ = self.gru(x)
        x = self.leaky_relu(x)
        return self.fc(x)



# 4. Sequence Generation Head
class SequenceGenerationHead(nn.Module):
    """Transformer decoder stack followed by a projection to vocabulary logits.

    Inputs are batch-first, i.e. ``x`` is (batch, tgt_len, d_model) and
    ``memory`` is (batch, src_len, d_model). ``nn.TransformerDecoderLayer``
    defaults to ``batch_first=False``, which would treat the batch axis as the
    sequence axis (and fail outright when tgt_len != src_len), so the layer is
    built batch-first by default.

    Args:
        d_model: embedding width (must be divisible by ``nhead``).
        nhead: number of attention heads.
        num_layers: number of stacked decoder layers.
        dim_feedforward: hidden width of each layer's feed-forward block.
        vocab_size: number of output logits per position.
        batch_first: layout of the inputs; pass False only for seq-first input.
    """

    def __init__(self, d_model, nhead, num_layers, dim_feedforward, vocab_size, batch_first=True):
        super(SequenceGenerationHead, self).__init__()
        decoder_layers = TransformerDecoderLayer(
            d_model, nhead, dim_feedforward, batch_first=batch_first
        )
        self.transformer_decoder = TransformerDecoder(decoder_layers, num_layers)
        self.fc_out = nn.Linear(d_model, vocab_size)

    def forward(self, x, memory, tgt_mask=None, tgt_key_padding_mask=None,
                memory_key_padding_mask=None):
        """Decode ``x`` against ``memory`` and return (batch, tgt_len, vocab_size) logits.

        The optional masks are forwarded unchanged to ``nn.TransformerDecoder``
        (e.g. a causal ``tgt_mask`` for autoregressive training).
        """
        output = self.transformer_decoder(
            x,
            memory,
            tgt_mask=tgt_mask,
            tgt_key_padding_mask=tgt_key_padding_mask,
            memory_key_padding_mask=memory_key_padding_mask,
        )
        return self.fc_out(output)

# 5. Complete Model
class TransformerModel(nn.Module):
    """Shared embedding + Transformer encoder with two task heads.

    * ``task='classification'`` -> output of ``ClassificationHead`` on the
      encoded sequence.
    * ``task='generation'`` -> ``SequenceGenerationHead`` decoding the
      embedded input (used as the decoder's target sequence) against the
      encoded sequence (used as memory).

    Args:
        vocab_size: size of the token vocabulary (embedding rows / output logits).
        d_model: model width, shared by all sub-modules.
        nhead: attention heads in encoder and decoder.
        num_layers: layers in encoder and in decoder.
        dim_feedforward: feed-forward width in encoder and decoder layers.
        num_classes: number of classification labels.
    """

    def __init__(self, vocab_size, d_model, nhead, num_layers, dim_feedforward, num_classes):
        super(TransformerModel, self).__init__()
        self.embedding = EmbeddingLayer(vocab_size, d_model)
        self.encoder = TransformerEncoderModule(d_model, nhead, num_layers, dim_feedforward)
        self.classification_head = ClassificationHead(d_model, num_classes)
        self.sequence_generation_head = SequenceGenerationHead(d_model, nhead, num_layers, dim_feedforward, vocab_size)

    def forward(self, x, task='classification'):
        """Run the shared trunk and the head selected by ``task``.

        Raises:
            ValueError: if ``task`` is neither 'classification' nor 'generation'
                (previously this silently returned None).
        """
        embedded = self.embedding(x)
        encoded = self.encoder(embedded)
        if task == 'classification':
            return self.classification_head(encoded)
        if task == 'generation':
            # NOTE(review): the decoder's target is the embedded *input*, not a
            # shifted target sequence, and no causal mask is applied.
            return self.sequence_generation_head(embedded, encoded)
        raise ValueError(
            f"unknown task {task!r}; expected 'classification' or 'generation'"
        )

# 6. Dataset: SQuAD contexts are the inputs, label-encoded article titles are
#    the classification labels, and questions are the generation targets.
squad_dataset = load_dataset("squad")
contexts = [item['context'] for item in squad_dataset['train']]
questions = [item['question'] for item in squad_dataset['train']]
titles = [item['title'] for item in squad_dataset['train']]
label_encoder = LabelEncoder()
encoded_titles = label_encoder.fit_transform(titles)
sentences = list(zip(contexts, encoded_titles, questions))
tokenizer = get_tokenizer("basic_english")
tokenized_sentences = [tokenizer(context) for context, _, _ in sentences]
tokenized_targets = [tokenizer(question) for _, _, question in sentences]

# Index 0 is reserved for padding. collate_fn pads with 0, so no real token
# may share that id.
PAD_INDEX = 0
vocab = {token for tokens in tokenized_sentences for token in tokens}
vocab.update(token for tokens in tokenized_targets for token in tokens)
# Sort before numbering. Set iteration order of str depends on the
# per-process hash seed, so an unsorted vocabulary gives different ids on
# every run and makes saved checkpoints meaningless on reload.
word_to_index = {word: index for index, word in enumerate(sorted(vocab), start=PAD_INDEX + 1)}
vocab_size = len(vocab) + 1  # +1 for the reserved padding row
indexed_sentences = [[word_to_index[token] for token in tokens] for tokens in tokenized_sentences]
indexed_targets = [[word_to_index[token] for token in tokens] for tokens in tokenized_targets]
Sentence = namedtuple('Sentence', ['text', 'label', 'target'])
# dtype=torch.long keeps an (unlikely) empty token list integer-typed.
dataset = [
    Sentence(torch.tensor(text_ids, dtype=torch.long), label, torch.tensor(target_ids, dtype=torch.long))
    for text_ids, label, target_ids in zip(indexed_sentences, encoded_titles, indexed_targets)
]

# Custom Dataset class
class SentenceDataset(Dataset):
    """Map-style ``Dataset`` wrapping an in-memory, indexable sequence of samples.

    Items are returned exactly as stored. No copying or transformation is done.
    """

    def __init__(self, sentences):
        self.sentences = sentences

    def __len__(self):
        """Number of stored samples."""
        count = len(self.sentences)
        return count

    def __getitem__(self, idx):
        """Return the sample stored at position ``idx``."""
        sample = self.sentences[idx]
        return sample

# Collate function for padding
def collate_fn(batch):
    """Stack a list of samples with ``text``/``label``/``target`` fields into tensors.

    Variable-length ``text`` and ``target`` tensors are right-padded with 0 to
    the longest item in the batch, giving (batch, max_len) tensors. Labels are
    collected into a 1-D tensor.

    Returns:
        (padded_texts, labels, padded_targets)
    """
    text_seqs, label_vals, target_seqs = [], [], []
    for sample in batch:
        text_seqs.append(sample.text)
        label_vals.append(sample.label)
        target_seqs.append(sample.target)
    label_tensor = torch.tensor(label_vals)
    padded_texts = pad_sequence(text_seqs, batch_first=True, padding_value=0)
    padded_targets = pad_sequence(target_seqs, batch_first=True, padding_value=0)
    return padded_texts, label_tensor, padded_targets

# DataLoader
# shuffle=True: the labels are encoded article titles, and the SQuAD train
# split is stored grouped by article (NOTE: verify for the dataset version in
# use). Iterating it in order would feed the same label for many consecutive
# batches, which badly skews the classification updates.
dataloader = DataLoader(SentenceDataset(dataset), batch_size=2, shuffle=True, collate_fn=collate_fn)

import time
import pickle

# 7. Training Loop
d_model = 1536  # Model dimension, must be divisible by nhead
nhead = 12     # Number of attention heads in each layer
num_layers = 12 # Number of layers (used for both encoder and decoder)
dim_feedforward = 3072 # Dimension of the feedforward network model
num_classes = len(set(encoded_titles))
num_epochs = 1

model = TransformerModel(vocab_size, d_model, nhead, num_layers, dim_feedforward, num_classes)
model.to(device) # Move the model to the device
model.train()    # explicit: dropout active during training
optimizer = optim.Adam(model.parameters())

# NOTE(review): padded positions (collate_fn pads with 0) currently contribute
# to both losses. Consider ignore_index for the padding id once it is
# guaranteed not to be a real token.
classification_loss_fn = nn.CrossEntropyLoss()
generation_loss_fn = nn.CrossEntropyLoss()

# Initialize a variable to keep track of the last checkpoint time
last_checkpoint_time = time.time()

# Save a checkpoint every `checkpoint_interval` batches (including batch 0)
checkpoint_interval = 200

for epoch in tqdm(range(num_epochs), desc='Epochs'):
    batch_iterator = tqdm(enumerate(dataloader), total=len(dataloader), desc='Batches', leave=False)
    for batch_idx, batch in batch_iterator:
        input_data, labels, targets = (tensor.to(device) for tensor in batch)

        # --- Classification step ---
        # NOTE(review): position -1 is a padding slot for every sample shorter
        # than the longest in the batch. Consider the last non-pad position.
        classification_output = model(input_data, task='classification')
        classification_loss = classification_loss_fn(classification_output[:, -1, :], labels)

        optimizer.zero_grad()
        # No retain_graph: the generation step below runs a fresh forward pass,
        # so this graph is never reused and can be freed immediately.
        classification_loss.backward()
        optimizer.step()

        # --- Sequence-generation step ---
        generation_output = model(input_data, task='generation')

        # Align per sample along the sequence axis *before* flattening. The old
        # code truncated the flattened logits to the flattened target length,
        # which pairs positions from different samples. It also called .view on
        # a non-contiguous slice, which raises.
        aligned_len = min(generation_output.size(1), targets.size(1))
        targets_aligned = targets[:, :aligned_len]
        generation_output_reshaped = generation_output[:, :aligned_len, :].reshape(-1, generation_output.size(-1))
        targets_reshaped = targets_aligned.reshape(-1)

        generation_loss = generation_loss_fn(generation_output_reshaped, targets_reshaped)

        optimizer.zero_grad()
        generation_loss.backward()
        optimizer.step()

        # Update the postfix of the tqdm loop with the current losses
        batch_iterator.set_postfix({
            'classification_loss': classification_loss.item(),
            'generation_loss': generation_loss.item(),
        })

        # Checkpoint the model every checkpoint_interval batches
        if batch_idx % checkpoint_interval == 0:
            checkpoint_path = f'checkpoint_epoch{epoch}_batch{batch_idx}.pth'
            torch.save(model.state_dict(), checkpoint_path)