In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from transformers import T5Tokenizer, T5Model
from datasets import load_dataset
from torch.utils.data import Dataset, DataLoader

class CustomT5EncoderDecoder(nn.Module):
    """Sequence-to-sequence model built from T5's token embedding plus freshly
    initialised ``nn.TransformerEncoder`` / ``nn.TransformerDecoder`` stacks.

    Only the shared embedding table of the pretrained checkpoint is used in
    ``forward``; the encoder, decoder and LM head defined here start from
    random weights.

    Args:
        model_name: Hugging Face checkpoint name used for the tokenizer and
            for the shared embedding table.
        d_model: Hidden width of the custom transformer stacks. Must equal the
            embedding width of the checkpoint.
        num_heads: Number of attention heads per layer.
        num_layers: Number of layers in both the encoder and the decoder.

    Raises:
        ValueError: If ``d_model`` differs from the checkpoint's ``d_model``.
    """

    def __init__(self, model_name="t5-small", d_model=512, num_heads=8, num_layers=6):
        super().__init__()

        self.tokenizer = T5Tokenizer.from_pretrained(model_name)
        self.t5_model = T5Model.from_pretrained(model_name)

        config = self.t5_model.config
        # The shared embedding feeds the custom stacks directly, so the widths
        # must agree; fail here instead of deep inside the attention layers.
        if config.d_model != d_model:
            raise ValueError(
                f"d_model={d_model} does not match the embedding width "
                f"{config.d_model} of checkpoint '{model_name}'"
            )

        # Token ids needed to build the decoder input during teacher forcing.
        pad_id = getattr(config, "pad_token_id", None)
        self.pad_token_id = pad_id if pad_id is not None else 0
        start_id = getattr(config, "decoder_start_token_id", None)
        self.decoder_start_token_id = start_id if start_id is not None else self.pad_token_id

        self.embedding = self.t5_model.shared

        # NOTE(review): no positional encoding is added to the embeddings in
        # forward(); presumably the shared table holds token embeddings only,
        # which would leave the encoder order-agnostic -- confirm and consider
        # adding one.
        self.encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=num_heads)
        self.encoder = nn.TransformerEncoder(self.encoder_layer, num_layers=num_layers)

        self.decoder_layer = nn.TransformerDecoderLayer(d_model=d_model, nhead=num_heads)
        self.decoder = nn.TransformerDecoder(self.decoder_layer, num_layers=num_layers)

        self.lm_head = nn.Linear(d_model, config.vocab_size)

    def forward(self, input_ids, attention_mask, target_ids):
        """Compute next-token logits for ``target_ids`` with teacher forcing.

        Args:
            input_ids: Source token ids, indexed as (batch, src_len).
            attention_mask: Same layout as ``input_ids``; non-zero marks real
                tokens and zero marks padding. ``None`` disables padding
                masking.
            target_ids: Target token ids, indexed as (batch, tgt_len).
                Negative ids (e.g. ``-100`` ignore labels) are replaced by the
                pad id before embedding.

        Returns:
            Logits indexed as (batch, tgt_len, vocab_size). Position ``i`` is
            computed from the source and from target tokens strictly before
            ``i``, so it can be scored directly against ``target_ids[:, i]``.
        """
        # Teacher forcing: the decoder consumes the targets shifted one step
        # to the right (start token first), so it never sees the token it is
        # asked to predict.
        start_column = target_ids.new_full(
            (target_ids.size(0), 1), self.decoder_start_token_id
        )
        decoder_input_ids = torch.cat([start_column, target_ids], dim=1)[:, :-1]
        decoder_input_ids = decoder_input_ids.masked_fill(
            decoder_input_ids < 0, self.pad_token_id
        )

        # True marks source positions that attention must skip.
        src_padding_mask = None
        if attention_mask is not None:
            src_padding_mask = attention_mask == 0

        # Upper-triangular mask: position i may only attend to positions <= i.
        tgt_len = decoder_input_ids.size(1)
        causal_mask = torch.triu(
            torch.ones(tgt_len, tgt_len, dtype=torch.bool, device=decoder_input_ids.device),
            diagonal=1,
        )

        # The transformer stacks were built with the default seq-first layout,
        # hence the (batch, seq, dim) -> (seq, batch, dim) permutes.
        input_embeddings = self.embedding(input_ids).permute(1, 0, 2)
        target_embeddings = self.embedding(decoder_input_ids).permute(1, 0, 2)

        encoder_outputs = self.encoder(
            input_embeddings, src_key_padding_mask=src_padding_mask
        )

        decoder_outputs = self.decoder(
            target_embeddings,
            encoder_outputs,
            tgt_mask=causal_mask,
            memory_key_padding_mask=src_padding_mask,
        )

        # Back to batch-first before projecting onto the vocabulary.
        decoder_outputs = decoder_outputs.permute(1, 0, 2)

        logits = self.lm_head(decoder_outputs)

        return logits


In [None]:
def load_and_tokenize_data(tokenizer, max_input_length=512, max_target_length=128):
    """Load CNN/DailyMail (config "3.0.0") and tokenise articles and highlights.

    Args:
        tokenizer: Callable tokenizer; invoked on batches of strings with
            ``padding="max_length"`` and ``truncation=True``.
        max_input_length: Fixed token length for the ``article`` field.
        max_target_length: Fixed token length for the ``highlights`` field.

    Returns:
        The result of ``dataset.map`` with the added ``input_ids``,
        ``attention_mask`` and ``target_ids`` fields.
    """

    def _encode(texts, limit):
        # Pad or truncate every sequence to exactly `limit` tokens.
        return tokenizer(texts, padding="max_length", truncation=True, max_length=limit)

    def tokenize_batch(batch):
        article_encoding = _encode(batch["article"], max_input_length)
        highlight_encoding = _encode(batch["highlights"], max_target_length)
        return {
            "input_ids": article_encoding["input_ids"],
            "attention_mask": article_encoding["attention_mask"],
            "target_ids": highlight_encoding["input_ids"],
        }

    raw_dataset = load_dataset("cnn_dailymail", "3.0.0")
    return raw_dataset.map(tokenize_batch, batched=True)

# Tokenizer for the "t5-small" checkpoint, used to preprocess the dataset below.
tokenizer = T5Tokenizer.from_pretrained("t5-small")

# Tokenise with the helper's default limits (512 input / 128 target tokens).
tokenized_dataset = load_and_tokenize_data(tokenizer)

class TextDataset(Dataset):
    """Map-style dataset exposing tokenised examples as ``torch.long`` tensors.

    Args:
        dataset: Mapping-like object whose ``"input_ids"``,
            ``"attention_mask"`` and ``"target_ids"`` entries are indexable
            sequences of integer lists.
    """

    def __init__(self, dataset):
        self.input_ids = dataset["input_ids"]
        self.attention_masks = dataset["attention_mask"]
        self.target_ids = dataset["target_ids"]

    def __len__(self):
        # All three columns are read in parallel; input_ids sets the length.
        return len(self.input_ids)

    def __getitem__(self, idx):
        """Return example ``idx`` as a dict of 1-D ``torch.long`` tensors."""
        columns = (
            ("input_ids", self.input_ids),
            ("attention_mask", self.attention_masks),
            ("target_ids", self.target_ids),
        )
        return {
            key: torch.tensor(values[idx], dtype=torch.long)
            for key, values in columns
        }

# Wrap the tokenised train and validation splits as PyTorch datasets.
train_dataset, val_dataset = (
    TextDataset(tokenized_dataset[split_name])
    for split_name in ("train", "validation")
)

# Only the training loader shuffles; validation keeps a fixed order.
train_loader = DataLoader(dataset=train_dataset, batch_size=4, shuffle=True)
val_loader = DataLoader(dataset=val_dataset, batch_size=4, shuffle=False)

In [None]:
# Seed PyTorch so weight initialisation and shuffling are repeatable.
SEED = 42
torch.manual_seed(SEED)

# Train on the GPU when one is available; otherwise stay on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = CustomT5EncoderDecoder().to(device)
# Targets are padded to a fixed length during preprocessing; skip those
# positions so the loss is averaged over real summary tokens only.
criterion = nn.CrossEntropyLoss(ignore_index=model.tokenizer.pad_token_id)
optimizer = optim.AdamW(model.parameters(), lr=5e-5)

epochs = 3


def _batch_loss(batch):
    """Move one batch to ``device`` and return its token-level cross-entropy."""
    input_ids = batch["input_ids"].to(device)
    attention_mask = batch["attention_mask"].to(device)
    target_ids = batch["target_ids"].to(device)

    logits = model(input_ids, attention_mask, target_ids)
    # reshape (not view) so a non-contiguous logits tensor is still accepted.
    return criterion(logits.reshape(-1, logits.size(-1)), target_ids.reshape(-1))


for epoch in range(epochs):
    model.train()
    total_loss = 0

    for batch in train_loader:
        optimizer.zero_grad()
        loss = _batch_loss(batch)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    avg_train_loss = total_loss / len(train_loader)
    print(f"Epoch {epoch+1}, Training Loss: {avg_train_loss}")

    model.eval()
    val_loss = 0

    # No gradients are needed for validation.
    with torch.no_grad():
        for batch in val_loader:
            val_loss += _batch_loss(batch).item()

    avg_val_loss = val_loss / len(val_loader)
    print(f"Epoch {epoch+1}, Validation Loss: {avg_val_loss}")
