In [38]:
from transformers import T5Tokenizer
from transformers import T5ForConditionalGeneration, T5Tokenizer
from dataset_utils import T5Dataset, collator
import torch
from torch.utils.data import DataLoader
import os

model_variants = ["google-t5/t5-base", "google-t5/t5-3b", "google/flan-t5-large", "google-t5/t5-11b"]
# Fall back to CPU so the notebook still runs on machines without a GPU.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
# None is acceptable for public checkpoints; gated/private ones need the token.
hf_token = os.getenv('HF_ACCESS_TOKEN')

# Load in float32: the model is fine-tuned below with a plain AdamW loop
# (no autocast / GradScaler). T5 checkpoints are widely reported to produce
# inf/NaN activations in pure float16, and fp16 parameters also swallow small
# optimizer updates. Switch to mixed precision explicitly if memory requires it.
model = T5ForConditionalGeneration.from_pretrained(
    model_variants[0],
    torch_dtype=torch.float32,
    token=hf_token,
).to(device)

tokenizer = T5Tokenizer.from_pretrained(
    model_variants[0],
    token=hf_token,
)



generation_config.json:   0%|          | 0.00/147 [00:00<?, ?B/s]

In [31]:
from torch.utils.data import Dataset
import torch.nn.utils
import numpy as np

class T5Dataset(Dataset):
    """T5-style span-corruption dataset over raw text lines.

    Each item whitespace-splits one text, masks random spans of words, and
    builds:

    * input:  unmasked words with each masked span replaced by ``<extra_id_k>``
    * target: ``<extra_id_0> span0 <extra_id_1> span1 ... <extra_id_n>``, where
      the trailing sentinel closes the last span (T5 paper convention).

    Args:
        texts: indexable sequence of strings.
        tokenizer: object exposing ``encode(text, truncation=..., max_length=...,
            return_tensors="pt")``.
        max_length: truncation length (tokenizer tokens) for input and target.
        corruption_rate: fraction of words to mask per text (at least one word).
        mean_span_length: mean of the Poisson distribution for span lengths.

    NOTE(review): input and target are truncated independently, so for texts
    longer than ``max_length`` tokens the target can reference sentinels that
    were cut from the input. T5 tokenizers presumably define only
    ``<extra_id_0>``..``<extra_id_99>`` — confirm; very long texts can need more
    spans than that.
    """

    def __init__(
        self,
        texts,
        tokenizer,
        max_length=512,
        corruption_rate=0.15,
        mean_span_length=3
    ):
        self.texts = texts
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.corruption_rate = corruption_rate
        self.mean_span_length = mean_span_length

    def corrupt_text(self, tokens):
        """Apply span corruption to a list of words.

        Args:
            tokens: list of word strings.

        Returns:
            ``(input_text, target_text)`` as space-joined strings.
        """
        mask = self.generate_span_mask(len(tokens))
        input_tokens = []
        target_tokens = []
        sentinel = 0
        in_span = False
        for token, masked in zip(tokens, mask):
            if masked:
                if not in_span:
                    # Open a new span: the same sentinel marks it in both texts.
                    input_tokens.append(f"<extra_id_{sentinel}>")
                    target_tokens.append(f"<extra_id_{sentinel}>")
                    sentinel += 1
                    in_span = True
                target_tokens.append(token)
            else:
                input_tokens.append(token)
                in_span = False

        # Close the target with the next unused sentinel. `sentinel` already
        # points one past the last opened span, so no extra +1 is needed, and
        # the input never receives this closing sentinel.
        target_tokens.append(f"<extra_id_{sentinel}>")

        return " ".join(input_tokens), " ".join(target_tokens)

    def generate_span_mask(self, seq_len):
        """Return a boolean mask of length ``seq_len`` marking words to corrupt.

        Spans start uniformly at random, have Poisson(``mean_span_length``)
        length (minimum 1, clipped at the sequence end), and never overlap a
        previously chosen span. Sampling stops once at least
        ``max(1, int(corruption_rate * seq_len))`` words (capped at ``seq_len``)
        are masked.
        """
        mask = np.zeros(seq_len, dtype=bool)
        if seq_len == 0:
            # Nothing to mask; np.random.randint(0, 0) would raise.
            return mask
        # Cap at seq_len so a corruption_rate > 1 cannot loop forever.
        num_tokens_to_mask = min(seq_len, max(1, int(self.corruption_rate * seq_len)))
        num_masked = 0
        while num_masked < num_tokens_to_mask:
            span_start = np.random.randint(0, seq_len)
            span_length = max(1, np.random.poisson(self.mean_span_length))
            span_end = min(seq_len, span_start + span_length)
            if mask[span_start:span_end].any():
                continue  # reject spans that overlap an existing one
            mask[span_start:span_end] = True
            num_masked += span_end - span_start
        return mask

    def _encode(self, text):
        """Tokenize ``text`` into a 1-D id tensor truncated to ``max_length``."""
        return self.tokenizer.encode(
            text,
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt"
        ).squeeze(0)

    def _make_item(self, input_text, target_text):
        """Build the ``{"input_ids", "labels"}`` dict for one example."""
        return {"input_ids": self._encode(input_text), "labels": self._encode(target_text)}

    def __getitem__(self, idx):
        tokens = self.texts[idx].strip().split()

        # Nothing to corrupt: return an (essentially) empty input/label pair.
        if not tokens:
            return self._make_item("", "")

        try:
            corrupted_input, target = self.corrupt_text(tokens)
            return self._make_item(corrupted_input, target)
        except Exception as e:
            # Best-effort: report and fall back to an empty pair so one bad
            # line does not abort the whole epoch.
            print(f"Error processing item {idx}: {e}")
            return self._make_item("", "")

    def __len__(self):
        return len(self.texts)


def collator(batch, tokenizer):
    """Pad a list of dataset items into one training batch.

    Args:
        batch: list of dicts holding 1-D ``input_ids`` and ``labels`` tensors.
        tokenizer: object exposing ``pad_token_id``.

    Returns:
        dict with ``input_ids`` padded with ``pad_token_id``, ``labels`` padded
        with -100, and ``attention_mask`` (1 for real tokens, 0 for padding),
        all shaped ``(len(batch), longest_sequence)``.

    Raises:
        Whatever item access or padding raises (e.g. ``KeyError``); the
        offending batch is printed first to aid debugging.
    """
    pad_id = tokenizer.pad_token_id
    try:
        input_ids = torch.nn.utils.rnn.pad_sequence(
            [item["input_ids"] for item in batch],
            batch_first=True,
            padding_value=pad_id,
        )
        # -100 is the loss ignore index, so padded label positions do not
        # contribute to training.
        labels = torch.nn.utils.rnn.pad_sequence(
            [item["labels"] for item in batch],
            batch_first=True,
            padding_value=-100,
        )
    except Exception:
        # Show the bad batch, then propagate: continuing would hand the
        # training loop undefined or unpadded tensors.
        print(batch)
        raise
    # Assumes pad_id never appears as a real input token (true for T5's
    # dedicated pad id — verify for other tokenizers).
    attention_mask = input_ids.ne(pad_id).long()
    return {"input_ids": input_ids, "labels": labels, "attention_mask": attention_mask}

In [32]:
def my_collator(batch):
    """DataLoader ``collate_fn``: pad ``batch`` with the notebook-level ``tokenizer``."""
    padded_batch = collator(batch, tokenizer=tokenizer)
    return padded_batch

import os

# Newline-delimited training corpus; set RAW_INPUT_PATH to use a different file
# instead of editing this machine-specific default.
raw_input = os.getenv(
    'RAW_INPUT_PATH',
    '/home/tadesa1/ADBMO-UNLV/data/processed_output_raw.txt',
)
# Explicit encoding so decoding does not depend on the machine's locale.
with open(raw_input, 'r', encoding='utf-8') as f:
    # One training example per non-blank line.
    text = [line for line in f if line.strip()]

dataset = T5Dataset(text, tokenizer)
dataloader = DataLoader(dataset, batch_size=2, shuffle=True, collate_fn=my_collator)


In [None]:
# Sanity-check one batch (shapes, padding, sentinels). Iterating the whole
# loader here would print every batch in the dataset.
batch = next(iter(dataloader))
print(batch["input_ids"].shape)  # (batch_size, input_seq_len)
print(batch["labels"].shape)     # (batch_size, label_seq_len)
print(batch["input_ids"])
print(batch["labels"])
# Decoded first example; -100 label padding is dropped before decoding.
first_labels = batch["labels"][0]
print(tokenizer.decode(batch["input_ids"][0]))
print(tokenizer.decode(first_labels[first_labels != -100]))

In [36]:
import torch
from transformers import T5ForConditionalGeneration, AdamW
from tqdm import tqdm

def train_t5_unsupervised(
    model,
    dataloader,
    optimizer=None,
    device="cuda" if torch.cuda.is_available() else "cpu",
    num_epochs=3,
    accumulation_steps=1,
    save_path="t5_finetuned.pt"
):
    """Fine-tune a seq2seq model on span-corruption batches.

    Args:
        model: module whose forward accepts ``input_ids``, ``attention_mask``
            and ``labels`` and returns an object with a scalar ``.loss``.
        dataloader: sized iterable of batch dicts with those three keys.
        optimizer: object with ``step()`` and ``zero_grad()``; defaults to
            ``torch.optim.AdamW(model.parameters(), lr=5e-5)``.
        device: device the model and every batch are moved to.
        num_epochs: number of passes over ``dataloader``.
        accumulation_steps: batches whose gradients are accumulated per
            optimizer step (the final window of an epoch may be shorter).
        save_path: checkpoint path template; ``"x.pt"`` yields
            ``x_epoch1.pt``, ``x_epoch2.pt``, ... (``.pt`` if no extension).

    Returns:
        list with the mean per-batch loss of each epoch.

    Raises:
        ValueError: if ``accumulation_steps < 1`` or ``dataloader`` is empty.
        TypeError: if ``optimizer`` lacks callable ``step``/``zero_grad``.
    """
    import os

    if accumulation_steps < 1:
        raise ValueError(f"accumulation_steps must be >= 1, got {accumulation_steps}")
    num_batches = len(dataloader)
    if num_batches == 0:
        raise ValueError("dataloader is empty")

    model = model.to(device)
    model.train()

    if optimizer is None:
        # torch.optim.AdamW supersedes the deprecated transformers.AdamW.
        optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)
    elif not (callable(getattr(optimizer, "step", None))
              and callable(getattr(optimizer, "zero_grad", None))):
        raise TypeError(
            f"optimizer must provide step() and zero_grad(), got {type(optimizer).__name__}"
        )

    # "t5_finetuned.pt" -> "t5_finetuned_epochN.pt" (not "t5_finetuned.pt_epochN.pt").
    base, ext = os.path.splitext(save_path)
    ext = ext or ".pt"

    epoch_losses = []
    for epoch in range(num_epochs):
        total_loss = 0.0
        # Start every epoch from clean gradients (none left over from earlier
        # cells or from a previous epoch).
        optimizer.zero_grad()
        loop = tqdm(dataloader, desc=f"Epoch {epoch+1}/{num_epochs}", leave=False)

        for step, batch in enumerate(loop):
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["labels"].to(device)

            outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
            batch_loss = outputs.loss

            # Divide by the size of this accumulation window (the last one may
            # be shorter) so each optimizer step applies a mean gradient.
            window_start = (step // accumulation_steps) * accumulation_steps
            window_size = min(accumulation_steps, num_batches - window_start)
            (batch_loss / window_size).backward()

            # Step after every full window AND after the final batch, so a
            # trailing partial window is applied instead of silently dropped.
            if (step + 1) % accumulation_steps == 0 or step + 1 == num_batches:
                optimizer.step()
                optimizer.zero_grad()

            loss_value = batch_loss.item()
            total_loss += loss_value
            loop.set_postfix(loss=loss_value)

        avg_loss = total_loss / num_batches
        epoch_losses.append(avg_loss)
        print(f"Epoch {epoch+1} Loss: {avg_loss:.4f}")

        # Save checkpoint
        torch.save(model.state_dict(), f"{base}_epoch{epoch+1}{ext}")

    return epoch_losses


In [None]:
# The third positional parameter is `optimizer`, not a collate function (the
# DataLoader already applies `my_collator`); let the function build its default
# optimizer and train on the same device the model was loaded onto.
epoch_losses = train_t5_unsupervised(model, dataloader, device=device)