In [None]:
import os

import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor
from tqdm import tqdm, trange
from transformers import AdamW, GPT2LMHeadModel, GPT2Tokenizer, get_linear_schedule_with_warmup



In [None]:
# Choose the compute backend in priority order: CUDA GPU, then Apple MPS,
# and finally plain CPU when no accelerator is reported as available.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
print(f"Using {device} device")

In [None]:
# Data loaders and PyTorch Dataset class (placeholder: not implemented yet)



In [1]:

# Hyperparameters shared by the model-building cells below.

# Transformer geometry
d_model = 512           # width of the model (feature dimension)
num_heads = 8           # heads in multi-headed attention
num_layers = 6          # stacked transformer layers
dim_feedforward = 2048  # width of the feed-forward sub-network

# Output head and regularisation
output_dim = 32         # output dimension of the MLP
dropout = 0.1           # dropout rate

In [None]:
# Load the pretrained GPT-2 tokenizer and language-model head.
gpt2_tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
gpt2_model = GPT2LMHeadModel.from_pretrained('gpt2').to(device)

# eval() only switches the module to inference mode (e.g. dropout off);
# it does not stop gradient tracking on its own.
gpt2_model.eval()

# The actual freezing: turn off gradient tracking for every GPT-2 parameter.
gpt2_model.requires_grad_(False)

In [None]:
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to a sequence tensor.

    For position ``p`` the encoding is the concatenation
    ``[sin(p * f_0), ..., sin(p * f_k), cos(p * f_0), ..., cos(p * f_k)]``
    with ``f_i = exp(-2i * ln(10000) / d_model)``; the sine block comes first,
    followed by the cosine block (the two are concatenated, not interleaved).

    Args:
        d_model: Size of the last dimension of the tensors given to
            ``forward``. The sine and cosine blocks each contribute
            ``len(range(0, d_model, 2))`` columns, so the encoding is exactly
            ``d_model`` wide only when ``d_model`` is even.
    """

    def __init__(self, d_model):
        super(PositionalEncoding, self).__init__()
        self.d_model = d_model

    def forward(self, x):
        """Return ``x`` with the position encoding added.

        Args:
            x: Tensor whose dimension 1 is the sequence axis and whose last
                dimension is ``d_model``, e.g. ``(batch, seq_len, d_model)``.

        Returns:
            Tensor of the same shape as ``x``.
        """
        seq_len = x.size(1)
        # Positions are created directly on x's device; one (seq_len, d_model)
        # table is enough because it broadcasts over the leading batch axis.
        pos = torch.arange(seq_len, device=x.device)
        pos_embedding = self.calc_pos_embedding(pos)
        return x + pos_embedding

    def calc_pos_embedding(self, pos):
        """Compute sinusoidal encodings for the given position indices.

        Args:
            pos: Tensor of position indices of any shape ``(...)``.

        Returns:
            Float tensor of shape ``(..., 2 * len(range(0, d_model, 2)))``.
        """
        pos = pos.float()
        # Frequencies live on the same device as the positions.
        exponents = torch.arange(0, self.d_model, 2, device=pos.device).float()
        factor = torch.exp(-exponents * (torch.log(torch.tensor(10000.0)) / self.d_model))
        # Broadcasted outer product; unlike torch.ger this accepts positions
        # of any rank, not only 1-D vectors.
        sinusoid_inp = pos.unsqueeze(-1) * factor
        pos_embedding = torch.cat([sinusoid_inp.sin(), sinusoid_inp.cos()], dim=-1)
        return pos_embedding

In [None]:
class TransformerDecoder(nn.Module):
    """Stack of ``nn.TransformerDecoderLayer`` followed by a linear projection.

    Args:
        d_model: Feature width of both the target and the memory sequence.
        num_heads: Number of attention heads per layer.
        dim_feedforward: Hidden width of each layer's feed-forward block.
        dropout: Dropout probability used inside the layers.
        num_layers: Number of stacked decoder layers.
        output_dim: Width of the final projection. When omitted, a
            module-level ``output_dim`` is used if one exists (this was the
            only behaviour before the argument was added); otherwise the
            projection keeps the width at ``d_model``.
    """

    def __init__(self, d_model, num_heads, dim_feedforward, dropout, num_layers, output_dim=None):
        super(TransformerDecoder, self).__init__()
        if output_dim is None:
            # Backward compatibility with the previous implicit dependency on
            # a notebook-level `output_dim` variable.
            output_dim = globals().get("output_dim", d_model)
        decoder_layer = nn.TransformerDecoderLayer(d_model, num_heads, dim_feedforward, dropout)
        self.transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers)
        self.output_layer = nn.Linear(d_model, output_dim)

    def forward(self, x, tgt):
        """Decode ``tgt`` while attending to the memory sequence ``x``.

        Args:
            x: Memory tensor, batch-first ``(batch, src_len, d_model)``.
            tgt: Target tensor, batch-first ``(batch, tgt_len, d_model)``.

        Returns:
            Tensor of shape ``(tgt_len, batch, output_dim)``. The result is
            left sequence-first, exactly as the decoder stack produces it.
        """
        # The layers are built with the default batch_first=False, so BOTH
        # inputs have to be sequence-first; permuting only `tgt` left the
        # batch axes of the two tensors misaligned.
        memory = x.permute(1, 0, 2)
        tgt = tgt.permute(1, 0, 2)
        # NOTE(review): no tgt_mask is passed, so target self-attention is
        # not causal - confirm that this is intended.
        output = self.transformer_decoder(tgt, memory)
        output = self.output_layer(output)
        return output


In [None]:
class MLP(nn.Module):
    """Single fully connected layer mapping ``input_dim`` to ``output_dim``.

    Despite the name there is no hidden layer and no non-linearity: the module
    is one ``nn.Linear`` stored as ``fc1``.
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.fc1 = nn.Linear(in_features=input_dim, out_features=output_dim)

    def forward(self, x):
        """Apply the linear layer to the last dimension of ``x``."""
        projected = self.fc1(x)
        return projected


In [None]:
class MyModel(nn.Module):
    """Positional encoding -> transformer decoder -> linear head.

    Args:
        d_model: Feature width of the input sequence.
        num_heads: Attention heads per decoder layer.
        dim_feedforward: Feed-forward width inside each decoder layer.
        dropout: Dropout probability inside the decoder layers.
        num_layers: Number of decoder layers.
        output_dim: Width of the final output produced by the MLP head.
    """

    def __init__(self, d_model, num_heads, dim_feedforward, dropout, num_layers, output_dim):
        super(MyModel, self).__init__()
        self.positional_encoding = PositionalEncoding(d_model)
        self.decoder = TransformerDecoder(d_model, num_heads, dim_feedforward, dropout, num_layers)
        # The decoder ends in its own projection, so the head must accept
        # whatever width that projection emits rather than assuming d_model.
        self.mlp = MLP(self.decoder.output_layer.out_features, output_dim)

    def forward(self, x, tgt=None):
        """Run the full model.

        Args:
            x: Input sequence; assumed batch-first ``(batch, seq_len, d_model)``
                - TODO confirm against the data pipeline.
            tgt: Optional target sequence handed to the decoder unchanged.
                When omitted, the position-encoded ``x`` is used as both
                memory and target.
                NOTE(review): the decoder requires a target, but the original
                call supplied none; self-conditioning on ``x`` is an
                assumption - confirm the intended target.

        Returns:
            The MLP head applied to the decoder output.
        """
        x = self.positional_encoding(x)
        if tgt is None:
            tgt = x
        x = self.decoder(x, tgt)
        x = self.mlp(x)
        return x


In [None]:
# Loss function: keep a reference to the callable. Calling F.cross_entropy()
# here with no arguments raises a TypeError before any training starts.
loss_fn = F.cross_entropy

# Your training loop here
# where you calculate the loss as per your requirement and perform backpropagation


def Train(dataset: "Add dataset", model: MyModel, args,
          lr: float = 2e-5, warmup_steps: int = 5000, output_dir: str = ".", output_prefix: str = ""):
    """Train ``model`` for 10 epochs against a loss computed on GPT-2 output.

    Relies on the notebook-level names ``device``, ``gpt2_model``,
    ``gpt2_tokenizer`` and ``loss_fn``.

    Args:
        dataset: Either a ready ``DataLoader`` (used as is) or a dataset that
            is wrapped in one. Each batch must unpack to
            ``(input_data, target)``.
        model: The trainable model; it is moved to ``device``.
        args: Optional namespace; only ``args.bs`` (batch size, default 1) is
            read, and only when ``dataset`` is not already a ``DataLoader``.
        lr: Learning rate for AdamW.
        warmup_steps, output_dir, output_prefix: Accepted but currently
            unused by this function.

    Returns:
        None. The loss of the last batch is printed once per epoch.
    """
    num_epochs = 10

    # The loop used to read an undefined global `train_loader`; build the
    # loader from the `dataset` argument instead.
    if isinstance(dataset, DataLoader):
        train_loader = dataset
    else:
        train_loader = DataLoader(dataset, batch_size=getattr(args, "bs", 1), shuffle=True)

    # Batches are moved to `device`, so the model has to live there as well.
    model = model.to(device)

    # Create the optimizer once: re-creating it on every step threw away
    # AdamW's running moment estimates.
    optimizer = AdamW(model.parameters(), lr=lr)

    for epoch in range(num_epochs):
        model.train()
        last_loss = None
        for input_data, target in train_loader:
            input_data, target = input_data.to(device), target.to(device)

            # Forward pass through your model
            output = model(input_data)

            # Tokenize the output of your model
            # NOTE(review): `output` holds floats, while the tokenizer expects
            # text; tokenisation is also not differentiable and GPT-2 is
            # frozen, so no gradient can reach `model` through this path.
            # The intended coupling to GPT-2 needs to be confirmed.
            tokens = gpt2_tokenizer(output.tolist(), return_tensors='pt').to(device)

            # Forward pass through GPT-2 model. The tokenizer returns a
            # mapping (input_ids, attention_mask) that must be unpacked, and
            # GPT2LMHeadModel exposes `logits`, not `last_hidden_state`.
            gpt2_output = gpt2_model(**tokens).logits

            # Remove the first 32 tokens from the GPT-2 output
            gpt2_output = gpt2_output[:, 32:]

            # Calculate loss
            # NOTE(review): cross_entropy expects the class axis at dim 1;
            # confirm the layout of `target` against these logits.
            loss = loss_fn(gpt2_output, target)

            # Backpropagate loss and update parameters
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            last_loss = loss.item()

        if last_loss is None:
            # An empty loader would otherwise hit an unbound `loss` here.
            print(f'Epoch: {epoch}, Loss: n/a (no batches)')
        else:
            print(f'Epoch: {epoch}, Loss: {last_loss}')

In [None]:
def train(dataset: "ClipCocoDataset", model: "ClipCaptionModel", args,
          lr: float = 2e-5, warmup_steps: int = 5000, output_dir: str = ".", output_prefix: str = ""):
    """Train a prefix-conditioned captioning model and checkpoint it.

    The annotations are quoted because ``ClipCocoDataset`` and
    ``ClipCaptionModel`` are not defined in this notebook; unquoted, they
    raise ``NameError`` as soon as the ``def`` statement runs.

    Args:
        dataset: Yields ``(tokens, mask, prefix)`` per sample and exposes
            ``prefix_length``.
        model: Called as ``model(tokens, prefix, mask)``; the result must
            expose ``.logits``.
        args: Namespace providing ``bs`` (batch size), ``epochs`` and
            ``save_every`` (checkpoint period in epochs).
        lr: Learning rate for AdamW.
        warmup_steps: Warm-up steps of the linear learning-rate schedule.
        output_dir: Directory for checkpoints; created when missing.
        output_prefix: Prefix of checkpoint file names and progress-bar label.

    Returns:
        The trained model, located on the notebook-level ``device``.
    """
    batch_size = args.bs
    epochs = args.epochs
    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(output_dir, exist_ok=True)
    # Use the device selected at the top of the notebook instead of a
    # hard-coded 'cuda:0', which fails on machines without CUDA.
    model = model.to(device)
    model.train()
    optimizer = AdamW(model.parameters(), lr=lr)
    train_dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True)
    scheduler = get_linear_schedule_with_warmup(
        optimizer, num_warmup_steps=warmup_steps, num_training_steps=epochs * len(train_dataloader)
    )
    # save_config(args)
    for epoch in range(epochs):
        # flush=True replaces the former sys.stdout.flush() call; `sys` was
        # never imported in this notebook.
        print(f">>> Training epoch {epoch}", flush=True)
        progress = tqdm(total=len(train_dataloader), desc=output_prefix)
        for idx, (tokens, mask, prefix) in enumerate(train_dataloader):
            model.zero_grad()
            tokens, mask, prefix = tokens.to(device), mask.to(device), prefix.to(device, dtype=torch.float32)
            outputs = model(tokens, prefix, mask)
            # Keep the logits from index prefix_length - 1 up to, but not
            # including, the last position.
            # NOTE(review): presumably this aligns each prediction with the
            # token that follows it - confirm against the model definition.
            logits = outputs.logits[:, dataset.prefix_length - 1: -1]
            # `F` is this notebook's alias for torch.nn.functional; the
            # previous alias `nnf` was never imported.
            loss = F.cross_entropy(logits.reshape(-1, logits.shape[-1]), tokens.flatten(), ignore_index=0)
            loss.backward()
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()
            progress.set_postfix({"loss": loss.item()})
            progress.update()
            if (idx + 1) % 10000 == 0:
                # Rolling checkpoint inside an epoch, overwritten each time.
                torch.save(
                    model.state_dict(),
                    os.path.join(output_dir, f"{output_prefix}_latest.pt"),
                )
        progress.close()
        if epoch % args.save_every == 0 or epoch == epochs - 1:
            torch.save(
                model.state_dict(),
                os.path.join(output_dir, f"{output_prefix}-{epoch:03d}.pt"),
            )
    return model