In [19]:
import pandas as pd
import string

# Specify the path to your TSV file
train_tsv_file_path = '/tmp/cyc/Train_GCC-training.tsv'
val_tsv_file_path = '/tmp/cyc/Validation_GCC-1.1.0-Validation.tsv'


def _read_first_column(tsv_path):
    """Return column 0 of a header-less, tab-separated file as a Series of strings.

    Args:
        tsv_path: Path of the TSV file to read.

    Returns:
        A pandas Series (named ``0``) holding the first field of every row as ``str``.
    """
    return pd.read_csv(
        tsv_path,
        delimiter='\t',
        header=None,
        usecols=[0],      # only column 0 is kept, so do not parse the others
        dtype=str,        # numeric-looking entries stay text
        na_filter=False,  # entries such as "None", "NA" or "" stay strings instead of NaN
    )[0]


# Read the first column of each TSV file into a Series
train_df = _read_first_column(train_tsv_file_path)
val_df = _read_first_column(val_tsv_file_path)

def remove_spaces(sentence):
    """Collapse whitespace and delete the space that precedes punctuation.

    Args:
        sentence: Text to clean (``str``).

    Returns:
        The text with leading/trailing whitespace removed, every run of
        whitespace reduced to one space, and no space directly in front of any
        character from ``string.punctuation`` (this includes opening brackets
        and quotes, e.g. ``"a ( b )"`` becomes ``"a( b)"``).
    """
    # Normalise whitespace *before* the punctuation pass. Doing it afterwards
    # left a stray space whenever punctuation was preceded by two or more
    # spaces or by a tab/newline ("hello  , world" -> "hello , world").
    sentence = ' '.join(sentence.split())
    for punctuation in string.punctuation:
        sentence = sentence.replace(f' {punctuation}', punctuation)
    return sentence

# Clean the captions of both splits with the same normaliser.
train_df, val_df = (
    captions.apply(remove_spaces) for captions in (train_df, val_df)
)

In [20]:
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import T5Tokenizer

class Seq2SeqDataset(Dataset):
    """Autoencoding dataset: each sentence is encoded once for T5 and once for CLIP.

    The target of every sample is the sample itself (input and target tensors
    are identical).

    Args:
        dataframe: Collection of sentences read with ``.iloc[idx]``.
        tokenizer: Object exposing ``encode_plus`` that returns a mapping with
            ``input_ids`` and ``attention_mask`` tensors.
        processor: Callable accepting ``text=`` and ``images=`` that returns a
            mapping with ``input_ids``, ``attention_mask`` and ``pixel_values``.
        max_length: Length every token sequence is padded or truncated to.
    """

    def __init__(self, dataframe, tokenizer, processor, max_length=64):
        self.data = dataframe
        self.tokenizer = tokenizer
        self.processor = processor
        self.max_length = max_length

    def __len__(self):
        """Return the number of sentences."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(t5_inputs, clip_inputs)`` for the sentence at position ``idx``.

        Both dicts carry ``input_ids``, ``attention_mask``, ``target_ids``,
        ``target_mask`` and the raw ``target`` sentence; ``clip_inputs``
        additionally carries ``pixel_values`` reshaped to ``(3, 224, 224)``.
        """
        source_sentence = self.data.iloc[idx]

        # Tokenize and encode the source sentence
        t5_tokens = self.tokenizer.encode_plus(
            source_sentence,
            add_special_tokens=True,
            max_length=self.max_length,
            return_tensors='pt',
            padding='max_length',
            truncation=True
        )

        # Drop only the leading axis (assumes encodings are (1, seq_len) — TODO
        # confirm). A bare squeeze() would also remove the sequence axis when
        # max_length == 1 and hand 0-d tensors to the collate function.
        t5_ids = t5_tokens['input_ids'].squeeze(0)
        t5_mask = t5_tokens['attention_mask'].squeeze(0)

        t5_inputs = {
            'input_ids': t5_ids,
            'attention_mask': t5_mask,
            'target_ids': t5_ids,  # Target is the same as the input
            'target_mask': t5_mask,
            'target': source_sentence
        }

        # The processor is given an all-zero placeholder image together with
        # the sentence; only the text carries information here.
        clip_tokens = self.processor(
            text=source_sentence,
            images=torch.zeros((3, 224, 224)),
            return_tensors="pt",
            padding='max_length',
            max_length=self.max_length,
            truncation=True
        )

        clip_ids = clip_tokens['input_ids'].squeeze(0)
        clip_mask = clip_tokens['attention_mask'].squeeze(0)

        clip_inputs = {
            'input_ids': clip_ids,
            'attention_mask': clip_mask,
            'pixel_values': clip_tokens["pixel_values"].view(3, 224, 224),
            'target_ids': clip_ids,  # Target is the same as the input
            'target_mask': clip_mask,
            'target': source_sentence
        }

        return t5_inputs, clip_inputs

In [24]:
import torch
from torch import nn, optim
from torch.utils.data import DataLoader
from transformers import T5ForConditionalGeneration, T5Tokenizer, CLIPModel
from tqdm import tqdm
        
class Bottleneck(nn.Module):
    """Linear projection from ``input_dim`` to ``output_dim`` followed by LayerNorm."""

    def __init__(self, input_dim, output_dim):
        super().__init__()
        # Attribute names double as state-dict key prefixes.
        self.layer = nn.Linear(input_dim, output_dim)
        self.norm = nn.LayerNorm(output_dim)

    def forward(self, x):
        """Project ``x`` with the linear layer and layer-normalise the result."""
        projected = self.layer(x)
        normalised = self.norm(projected)
        return normalised

class ClipEval(nn.Module):
    """CLIP encoder, Bottleneck and T5 decoder bundled into one module.

    The forward pass is not implemented yet.

    Args:
        t5_model_path: Path of a checkpoint readable by ``torch.load``; its
            contents are passed to the T5 decoder's ``load_state_dict``.
            Assumes the checkpoint holds a ``T5ForConditionalGeneration``
            state dict with matching keys — TODO confirm.
    """

    def __init__(self, t5_model_path):
        # Must initialise ClipEval's own nn.Module base. The previous
        # super(Bottleneck, self) raised TypeError because ClipEval is not a
        # Bottleneck subclass.
        super().__init__()
        self.encoder = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
        # Presumably maps CLIP's 512-d embedding to T5-base's 768-d hidden
        # size — TODO confirm.
        self.bottleneck = Bottleneck(512, 768)
        self.decoder = T5ForConditionalGeneration.from_pretrained('t5-base')
        # map_location='cpu' lets a checkpoint saved from a GPU load on any
        # machine; move the whole module with .to(device) afterwards.
        self.decoder.load_state_dict(torch.load(t5_model_path, map_location='cpu'))

    def forward(self, clip_inputs, t5_inputs, train=True):
        """Placeholder forward pass.

        Args:
            clip_inputs: Inputs intended for the CLIP encoder.
            t5_inputs: Inputs intended for the T5 decoder.
            train: Selects the training path when true, the evaluation path otherwise.

        Raises:
            NotImplementedError: Always. Neither path is written yet, and
                failing loudly is safer than silently returning ``None``.
        """
        if train:
            raise NotImplementedError("ClipEval.forward: training path is not implemented yet")
        raise NotImplementedError("ClipEval.forward: evaluation path is not implemented yet")

In [25]:
# Use the GPU when one is available; fall back to CPU so the cell still runs without one.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Load the T5 tokenizer
tokenizer = T5Tokenizer.from_pretrained("t5-base")

# NOTE(review): this cell looks like it depends on definitions from an earlier
# notebook state. `Autoencoder` is not defined in any visible cell, and the
# visible `Seq2SeqDataset` requires a `processor` argument and returns a
# `(t5_inputs, clip_inputs)` tuple, whereas this cell passes two arguments and
# indexes each batch as a single dict. Confirm before a Restart & Run All.

# Create the dataset and DataLoader
train_dataset = Seq2SeqDataset(train_df, tokenizer)
val_dataset = Seq2SeqDataset(val_df, tokenizer)
train_dataloader = DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=4)
val_dataloader = DataLoader(val_dataset, batch_size=64, shuffle=False, num_workers=4)

# Initialize the autoencoder model
autoencoder_model = Autoencoder().to(device)

# Define the optimizer and learning rate scheduler
optimizer = optim.AdamW(autoencoder_model.parameters(), lr=5e-5)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.9)

# Training loop
num_epochs = 3
# Best validation EM seen so far. It lives outside the epoch loop so a later,
# worse epoch cannot overwrite the checkpoint of an earlier, better one.
best_em_score = 0.0

for epoch in range(num_epochs):
    total_loss = 0
    autoencoder_model.train()

    for batch in tqdm(train_dataloader, desc=f'Epoch {epoch + 1}/{num_epochs}'):
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        target_ids = batch['target_ids'].to(device)
        target_mask = batch['target_mask'].to(device)

        # With four arguments the model is expected to return the training loss.
        loss = autoencoder_model(input_ids, attention_mask, target_ids, target_mask)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    # max(..., 1) avoids a ZeroDivisionError when the loader yields no batches.
    average_loss = total_loss / max(len(train_dataloader), 1)
    print(f'Epoch {epoch + 1}/{num_epochs}, Average Loss: {average_loss}')

    # Optionally update the learning rate
    scheduler.step()

    # Evaluate with Exact Match (EM) on a validation set
    autoencoder_model.eval()
    with torch.no_grad():
        em_count = 0
        total_samples = 0

        for val_batch in tqdm(val_dataloader, desc=f'Validation - Epoch {epoch + 1}'):
            input_ids = val_batch['input_ids'].to(device)
            attention_mask = val_batch['attention_mask'].to(device)

            # Generate sequences (with two arguments the model is expected to return token ids)
            generated_ids = autoencoder_model(input_ids, attention_mask).cpu().numpy()

            # Decode token IDs to strings
            generated_sentences = [tokenizer.decode(ids, skip_special_tokens=True) for ids in generated_ids]
            target_sentences = val_batch['target']

            # Check for exact match
            em_count += sum(1 for gen, target in zip(generated_sentences, target_sentences) if gen == target)
            total_samples += len(generated_sentences)

        # An empty validation set scores 0.0 instead of dividing by zero.
        em_score = em_count / total_samples if total_samples else 0.0
        print(f'Validation EM Score: {em_score}')

        # Save the model if the EM score improves
        if em_score > best_em_score:
            best_em_score = em_score
            torch.save(autoencoder_model.state_dict(), 't5_model.pth')
            print("Model saved!")



Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
model.safetensors: 100%|██████████| 892M/892M [00:36<00:00, 24.2MB/s] 
generation_config.json: 100%|██████████| 147/147 [00:00<00:00, 1.45MB/s]
Epoch 1/3:   1%|          | 302/51849 [00:59<2:49:48,  5.06it/s]


KeyboardInterrupt: 

In [14]:
!pip install sentencepiece

Collecting sentencepiece
  Downloading sentencepiece-0.1.99-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.3 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.3/1.3 MB[0m [31m608.7 kB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0m
[?25hInstalling collected packages: sentencepiece
Successfully installed sentencepiece-0.1.99
