In [1]:
import numpy as np
import random
import torch
import torch.nn as nn
import torch.optim as optim
from torch_geometric.nn import GCNConv
from transformers import T5Tokenizer, T5ForConditionalGeneration
import spacy
from torch.nn import CrossEntropyLoss
from datasets import load_dataset

seed = 42

# Seed every RNG the notebook uses (Python, NumPy, PyTorch and, when a GPU is
# present, all CUDA devices) so that runs are repeatable.
for _seed_rng in (random.seed, np.random.seed, torch.manual_seed):
    _seed_rng(seed)
del _seed_rng
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(seed)

In [2]:
# Load e-SNLI; keep the full validation split and every 10th training example
# (a 10% subsample of the training split).
dataset = load_dataset("esnli")

eval_dataset = dataset['validation']
indices = list(range(0, len(dataset['train']), 10))
train_dataset = dataset['train'].select(indices)

len(train_dataset), len(eval_dataset)

Reusing dataset esnli (/home/ec2-user/.cache/huggingface/datasets/esnli/plain_text/0.0.2/a160e6a02bbb8d828c738918dafec4e7d298782c334b5109af632fec6d779bbc)


  0%|          | 0/3 [00:00<?, ?it/s]

(54937, 9842)

In [3]:
# e-SNLI integer label -> label word used as the prefix of the target text.
label_dct = dict(enumerate(("entailment", "neutral", "contradiction")))

In [4]:
# Tokenizer saved alongside the fine-tuned checkpoint from the earlier experiment.
CHECKPOINT_STEP = 18887
checkpoint_path = f"../expt1/flan_t5_esnli/checkpoint-{CHECKPOINT_STEP}"
tokenizer = T5Tokenizer.from_pretrained(checkpoint_path)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [5]:
# spaCy pipeline used to dependency-parse premise/hypothesis text.
nlp = spacy.load("en_core_web_sm")


def get_dependency_graph(text):
    """Dependency-parse ``text`` and return (node features, edge list).

    Returns:
        nodes: list with one tensor per spaCy token -- the tokenizer ``input_ids``
            of that token's text, truncated/padded to length 512.
        edges: ``[heads, dependents]``, two parallel lists of spaCy token indices,
            one entry per token whose dependency label is not ``'punct'``.
            Punctuation tokens still appear in ``nodes``; they just get no edge.
    """
    parsed = nlp(text)

    nodes = []
    for tok in parsed:
        encoded = tokenizer(tok.text, truncation=True, padding="max_length", max_length=512, return_tensors="pt")
        nodes.append(encoded["input_ids"][0])

    arcs = [(tok.head.i, tok.i) for tok in parsed if tok.dep_ != 'punct']
    heads = [head for head, _ in arcs]
    dependents = [dependent for _, dependent in arcs]
    return nodes, [heads, dependents]

In [6]:
# Preprocessing function
def preprocess(example, max_nodes=120):
    """Convert one e-SNLI record into fixed-size model inputs.

    Args:
        example: dataset row with ``premise``, ``hypothesis``, ``label`` and
            ``explanation_1``..``explanation_3`` fields.
        max_nodes: size the graph tensors are padded (or truncated) to.
            Defaults to 120.

    Returns:
        dict with
          - ``input_ids`` / ``attention_mask``: the prompt, padded to 512 tokens;
          - ``labels``: "<label word>: <explanations>" target, padded to 512 tokens,
            with padding positions set to -100 so ``CrossEntropyLoss(ignore_index=-100)``
            excludes them from the loss;
          - ``combined_nodes``: (max_nodes, 512) node features; all-zero rows are padding;
          - ``combined_edges``: (2, max_nodes) ``[heads; dependents]`` edge index;
            columns filled with -1 are padding.
    """
    # Prepare input and output text
    input_text = f"Premise: {example['premise']} Hypothesis: {example['hypothesis']} What is the relationship? Explain your answer."
    output_text = f"{label_dct[example['label']]}: {example['explanation_1']}. {example['explanation_2']}. {example['explanation_3']}."

    # Tokenize input and output
    input_encoding = tokenizer(input_text, truncation=True, padding="max_length", max_length=512, return_tensors="pt")
    output_encoding = tokenizer(output_text, truncation=True, padding="max_length", max_length=512, return_tensors="pt")

    # Padded target positions must not be trained on: mark them with the loss's ignore_index.
    labels = output_encoding["input_ids"][0].clone()
    labels[labels == tokenizer.pad_token_id] = -100

    # Dependency graph over the combined premise + hypothesis text.
    nodes, (heads, dependents) = get_dependency_graph(f"Premise: {example['premise']} Hypothesis: {example['hypothesis']}")

    # Truncate texts with more than max_nodes tokens, then zero-pad up to max_nodes rows.
    combined_nodes = torch.stack(nodes)[:max_nodes]
    combined_nodes = nn.functional.pad(combined_nodes, (0, 0, 0, max_nodes - combined_nodes.shape[0]))

    # Drop edges touching a truncated node so the edge index never points past the node table.
    # Each dependent token contributes at most one edge, so at most max_nodes edges remain.
    kept = [(head, dep) for head, dep in zip(heads, dependents) if head < max_nodes and dep < max_nodes]
    combined_edges = torch.tensor([[head for head, _ in kept], [dep for _, dep in kept]], dtype=torch.long)
    combined_edges = nn.functional.pad(combined_edges, (0, max_nodes - combined_edges.shape[1]), value=-1)

    return {
        "input_ids": input_encoding["input_ids"][0],  # Remove batch dimension
        "attention_mask": input_encoding["attention_mask"][0],  # Remove batch dimension
        "labels": labels,
        "combined_nodes": combined_nodes,
        "combined_edges": combined_edges,
    }

# Apply `preprocess` to both splits, dropping the raw text/label columns it consumed.
RAW_COLUMNS = ('premise', 'hypothesis', 'label', 'explanation_1', 'explanation_2', 'explanation_3')
train_dataset, eval_dataset = (
    split.map(preprocess, remove_columns=list(RAW_COLUMNS), load_from_cache_file=False)
    for split in (train_dataset, eval_dataset)
)

  0%|          | 0/54937 [00:00<?, ?ex/s]

  0%|          | 0/9842 [00:00<?, ?ex/s]

In [7]:
# Expose only the model-input columns, as torch tensors, from both splits.
MODEL_COLUMNS = ("input_ids", "attention_mask", "labels", "combined_nodes", "combined_edges")
for split in (train_dataset, eval_dataset):
    split.set_format(type="torch", columns=list(MODEL_COLUMNS))

# Sanity check: every text tensor is padded to 512 tokens, so each line should read torch.Size([512]).
first_example = train_dataset[0]
for column in ("input_ids", "attention_mask", "labels"):
    print(first_example[column].shape)

torch.Size([512])
torch.Size([512])
torch.Size([512])


In [8]:
class GCN(nn.Module):
    """Two-layer graph convolutional network producing one embedding per node.

    Args:
        in_channels: length of each input node feature vector.
        out_channels: length of each output node embedding.

    The hidden layer has 64 channels followed by a ReLU.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        hidden_channels = 64
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, out_channels)

    def forward(self, x, edge_index):
        hidden_states = self.conv1(x, edge_index).relu()
        return self.conv2(hidden_states, edge_index)

In [9]:
class T5WithUnifiedGCN(nn.Module):
    """T5 encoder-decoder whose cross-attention memory is extended with GCN node embeddings.

    The T5 encoder states of the text and the GCN embeddings of the dependency
    graph are concatenated along the sequence axis, and the T5 decoder
    cross-attends over that combined memory.

    NOTE: the graph inputs describe a single example (the GCN output gets a
    leading batch dimension of 1), so ``input_ids`` must have batch size 1.
    """

    def __init__(self, t5_model, gcn_model):
        super(T5WithUnifiedGCN, self).__init__()
        self.t5 = t5_model
        self.gcn = gcn_model

    @staticmethod
    def _shift_labels_right(labels, decoder_start_token_id, pad_token_id):
        """Build teacher-forcing decoder inputs: prepend the start token, drop the last label, map -100 to pad."""
        shifted = labels.new_zeros(labels.shape)
        shifted[..., 1:] = labels[..., :-1].clone()
        shifted[..., 0] = decoder_start_token_id
        return shifted.masked_fill(shifted == -100, pad_token_id)

    def forward(self, input_ids, attention_mask, combined_nodes, combined_edges, labels):
        """Run encoder, GCN and decoder; return ``(lm_logits, *decoder_extras, *encoder_outputs)``.

        Args:
            input_ids: (1, seq_len) text token ids.
            attention_mask: (1, seq_len) padding mask for ``input_ids``.
            combined_nodes: (max_nodes, feat_dim) node features; all-zero rows are padding.
            combined_edges: (2, max_edges) edge index; entries equal to -1 are padding.
            labels: (1, tgt_len) target ids (-100 marks ignored positions), or None.
                With labels, the decoder is teacher-forced on the labels shifted right,
                so ``lm_logits[:, i]`` predicts ``labels[:, i]``. Without labels, a single
                decoder step from the start token is run.
        """
        # Step 1: text encoder states.
        encoder_outputs = self.t5.encoder(input_ids=input_ids, attention_mask=attention_mask)
        t5_embeddings = encoder_outputs.last_hidden_state

        # Step 2: graph embeddings (strip the -1 edge padding and all-zero node rows).
        edge_index = combined_edges[combined_edges != -1].reshape((2, -1))
        real_nodes = combined_nodes[~(combined_nodes == 0).all(dim=1)]
        graph_embeddings = self.gcn(real_nodes.to(torch.float), edge_index.to(torch.long)).unsqueeze(0)

        # Step 3: joint memory for cross-attention. Graph nodes are always attendable;
        # the text part keeps its padding mask so padded positions are not attended.
        combined_embeddings = torch.cat([t5_embeddings, graph_embeddings], dim=1)
        graph_mask = attention_mask.new_ones((attention_mask.size(0), graph_embeddings.size(1)))
        encoder_attention_mask = torch.cat([attention_mask, graph_mask], dim=1)

        # Step 4: decoder inputs -- the right-shifted targets (teacher forcing), or the start token.
        config = self.t5.config
        if labels is not None:
            decoder_input_ids = self._shift_labels_right(labels, config.decoder_start_token_id, config.pad_token_id)
        else:
            decoder_input_ids = torch.full(
                (input_ids.size(0), 1), config.decoder_start_token_id, dtype=torch.long, device=input_ids.device
            )

        decoder_outputs = self.t5.decoder(
            input_ids=decoder_input_ids,
            encoder_hidden_states=combined_embeddings,
            encoder_attention_mask=encoder_attention_mask,
        )
        sequence_output = decoder_outputs[0]

        if config.tie_word_embeddings:
            # Rescale output before projecting on vocab (as T5ForConditionalGeneration does)
            # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
            sequence_output = sequence_output * (config.d_model ** -0.5)

        lm_logits = self.t5.lm_head(sequence_output)

        # The loss is computed by the caller from lm_logits and labels.
        return (lm_logits,) + decoder_outputs[1:] + encoder_outputs[0:]

In [10]:
def generate_output_sequences(model, input_ids, attention_mask, combined_nodes, combined_edges, tokenizer, max_length=50, num_beams=5):
    """
    Generates sequences from the T5WithUnifiedGCN model, which includes the T5 model and Graph Convolution Network.

    Args:
        model: T5WithUnifiedGCN model
        input_ids: Tensor of input token IDs (batch_size, seq_length)
        attention_mask: Attention mask for the input (batch_size, seq_length)
        combined_nodes: The combined graph nodes (batch_size, num_nodes)
        combined_edges: The combined graph edges (batch_size, num_edges)
        tokenizer: The T5 tokenizer
        max_length: The maximum length of the generated sequences
        num_beams: Number of beams for beam search (for controlled generation)

    Returns:
        Generated sequences (List of strings)
    """
    
    # Ensure the model is in evaluation mode
    model.eval()

    # Pass inputs through the model
    with torch.no_grad():
        # Get the logits and other outputs from the forward pass
        outputs = model(input_ids=input_ids, 
                        attention_mask=attention_mask, 
                        combined_nodes=combined_nodes, 
                        combined_edges=combined_edges, 
                        labels=None)  # We do not need labels during inference
        
        lm_logits = outputs[0]  # (batch_size, seq_len, vocab_size)

    # Decode the output sequences
    # Get the token probabilities and use the argmax or sample from the distribution
    # You can use beam search or greedy decoding for generating sequences

    generated_sequences = []
    for batch_idx in range(lm_logits.size(0)):  # Iterate over each batch
        # Use beam search or greedy decoding (beam search is usually more sophisticated)
        # Using beam search here for better results
        generated_ids = model.t5.generate(
            input_ids=input_ids[batch_idx:batch_idx+1],  # Process each sample individually
            attention_mask=attention_mask[batch_idx:batch_idx+1],
            max_length=max_length,
            num_beams=num_beams,
            early_stopping=True
        )
        
        # Decode the generated token IDs to strings
        decoded_sequence = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
        generated_sequences.append(decoded_sequence)

    return generated_sequences

In [12]:
# Combined model: the T5 checkpoint at `checkpoint_path` plus a GCN over the dependency graph.
# out_channels must equal the T5 hidden size because GCN node embeddings are
# concatenated with the T5 encoder states along the sequence axis.
t5_model = T5ForConditionalGeneration.from_pretrained(checkpoint_path)
gcn_model = GCN(in_channels=512, out_channels=512)
model = T5WithUnifiedGCN(t5_model, gcn_model)

In [None]:
#outed = model(train_dataset[0]["input_ids"].unsqueeze(0), train_dataset[0]["attention_mask"].unsqueeze(0),
#              train_dataset[0]["combined_nodes"], train_dataset[0]["combined_edges"], train_dataset[0]["labels"].unsqueeze(0))

In [20]:
# Smoke-test generation on one training example.
demo_example = train_dataset[1]
generate_output_sequences(
    model,
    demo_example["input_ids"].unsqueeze(0),
    demo_example["attention_mask"].unsqueeze(0),
    torch.stack(demo_example["combined_nodes"]),
    torch.stack(demo_example["combined_edges"]),
    tokenizer,
)

['']

In [15]:
import torch
from torch.optim import AdamW
from torch.nn import CrossEntropyLoss
from tqdm import tqdm


# Optimizer over all T5 + GCN parameters.
learning_rate = 0.001
optimizer = AdamW(model.parameters(), lr=learning_rate, betas=(0.9, 0.999), eps=1e-08)

# Token-level loss for sequence generation; target positions labelled -100 are excluded.
loss_fct = CrossEntropyLoss(ignore_index=-100)

# Training runs on the CPU.
device = 'cpu'
model.to(device)

# Manual training loop: one example per optimizer step, no DataLoader.
def train_without_dataloader(model, dataset, optimizer, loss_fct, num_epochs=3):
    """Fine-tune ``model`` on ``dataset`` one row at a time.

    For every row, the text tensors get a leading batch dimension of 1 and the
    graph tensors are stacked from their per-row lists. Everything is moved to
    the module-level ``device``. One optimizer step is then taken on
    ``loss_fct`` over the flattened logits and labels. After each epoch the
    mean per-row loss is printed.
    """
    model.train()
    for epoch in range(num_epochs):
        epoch_no = epoch + 1
        loss_sum = 0
        for row_idx in tqdm(range(len(dataset)), desc=f"Training Epoch {epoch_no}/{num_epochs}"):
            row = dataset[row_idx]

            input_ids, attention_mask, labels = (
                row[key].to(device).unsqueeze(0) for key in ("input_ids", "attention_mask", "labels")
            )
            combined_nodes, combined_edges = (
                torch.stack(row[key]).to(device) for key in ("combined_nodes", "combined_edges")
            )

            optimizer.zero_grad()
            lm_logits = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                combined_nodes=combined_nodes,
                combined_edges=combined_edges,
                labels=labels,
            )[0]
            flat_logits = lm_logits.view(-1, lm_logits.size(-1))
            loss = loss_fct(flat_logits, labels.view(-1))

            loss.backward()
            optimizer.step()
            loss_sum += loss.item()

        print(f"Epoch {epoch_no}, Loss: {loss_sum / len(dataset)}")

# One epoch over the subsampled training split.
train_without_dataloader(model=model, dataset=train_dataset, optimizer=optimizer, loss_fct=loss_fct, num_epochs=1)

Training Epoch 1/1:  66%|████████████████████████████████████████████████████████████████████████████████████████████▏                                              | 36452/54937 [13:12:29<6:41:52,  1.30s/it]


KeyboardInterrupt: 

In [18]:
# Display the module tree of the combined T5 + GCN model.
model

T5WithUnifiedGCN(
  (t5): T5ForConditionalGeneration(
    (shared): Embedding(32128, 512)
    (encoder): T5Stack(
      (embed_tokens): Embedding(32128, 512)
      (block): ModuleList(
        (0): T5Block(
          (layer): ModuleList(
            (0): T5LayerSelfAttention(
              (SelfAttention): T5Attention(
                (q): Linear(in_features=512, out_features=384, bias=False)
                (k): Linear(in_features=512, out_features=384, bias=False)
                (v): Linear(in_features=512, out_features=384, bias=False)
                (o): Linear(in_features=384, out_features=512, bias=False)
                (relative_attention_bias): Embedding(32, 6)
              )
              (layer_norm): T5LayerNorm()
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (1): T5LayerFF(
              (DenseReluDense): T5DenseGatedActDense(
                (wi_0): Linear(in_features=512, out_features=1024, bias=False)
                (wi_1): Linear(