<a href="https://colab.research.google.com/github/OneFineStarstuff/OneFineStarstuff/blob/main/LongformerClassificationModel(nn_Module).ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import AdamW
from transformers import LongformerModel

class LongformerClassificationModel(nn.Module):
    """Sequence classifier: a pretrained Longformer encoder plus a linear head.

    The head reads the encoder's final hidden state at position 0. The class
    also offers ``fast_adapt``, a first-order MAML episode for few-shot use.

    Args:
        model_name: Model id or local path passed to
            ``LongformerModel.from_pretrained``.
        num_labels: Number of output classes (default 2).
    """

    def __init__(self, model_name, num_labels=2):
        super().__init__()
        self.model = LongformerModel.from_pretrained(model_name)
        self.classifier = nn.Linear(self.model.config.hidden_size, num_labels)

    def forward(self, input_ids, attention_mask, global_attention_mask=None):
        """Return class logits of shape ``(batch, num_labels)``.

        Args:
            input_ids: Token ids; indexing below assumes ``(batch, seq_len)``.
            attention_mask: Passed through to the encoder (HF convention:
                1 = attend, 0 = padding).
            global_attention_mask: Optional Longformer global-attention mask
                (1 = global). When omitted, only the first token gets global
                attention. The position-0 representation the classifier reads
                can then attend to the whole sequence, not just its local
                sliding window.
        """
        if global_attention_mask is None:
            global_attention_mask = torch.zeros_like(input_ids)
            global_attention_mask[:, 0] = 1
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            global_attention_mask=global_attention_mask,
        )
        # Position 0 is the sequence-start (<s>/CLS) token when inputs come from
        # the matching tokenizer.
        return self.classifier(outputs.last_hidden_state[:, 0, :])

    def clone_parameters(self):
        """Return a detached copy of every parameter, keyed by qualified name."""
        return {name: param.detach().clone() for name, param in self.named_parameters()}

    def load_parameters(self, params):
        """Copy values from ``params`` (e.g. from ``clone_parameters``) into the model in place.

        Raises:
            KeyError: If ``params`` has no entry for one of the model's parameters.
        """
        with torch.no_grad():
            for name, param in self.named_parameters():
                param.copy_(params[name])

    def fast_adapt(self, support_data, query_data, optimizer=None, n_steps=5, lr_inner=1e-3):
        """Run one first-order MAML episode and return the query loss.

        The inner loop takes ``n_steps`` plain-SGD steps of size ``lr_inner``
        on the support set. The steps are applied to *functional copies* of the
        parameters, so the module's own weights, their ``.grad`` fields and the
        outer optimizer's state are never modified.

        The query loss is evaluated at the adapted weights. It stays connected
        to the real parameters through an identity path, so calling
        ``query_loss.backward()`` writes the first-order meta-gradient into
        ``param.grad``. The caller then steps the outer optimizer.

        Args:
            support_data: ``(input_ids, attention_mask, targets)`` used for adaptation.
            query_data: ``(input_ids, attention_mask, targets)`` used for evaluation.
            optimizer: Not used by the inner loop. Stepping the outer (meta)
                optimizer there would ignore ``lr_inner`` and overwrite its
                moment estimates. The argument is kept so existing call sites
                keep working; step the optimizer yourself after ``backward()``.
            n_steps: Number of inner-loop gradient steps.
            lr_inner: Inner-loop SGD step size.

        Returns:
            Scalar cross-entropy loss on the query set at the adapted weights.
        """
        from torch.func import functional_call

        support_input, support_attention, support_target = support_data
        query_input, query_attention, query_target = query_data

        # Frozen parameters are not adapted; functional_call falls back to the
        # module's own tensors for any name missing from this dict.
        fast_weights = {
            name: param for name, param in self.named_parameters() if param.requires_grad
        }
        for _ in range(n_steps):
            logits = functional_call(self, fast_weights, (support_input, support_attention))
            support_loss = F.cross_entropy(logits, support_target)
            # allow_unused: some weights (e.g. an encoder pooler, if present) may
            # not feed the logits and so receive no gradient.
            grads = torch.autograd.grad(
                support_loss, tuple(fast_weights.values()), allow_unused=True
            )
            # Detaching the gradient makes this first-order. Each fast weight is
            # the original parameter minus a constant, so the Jacobian back to
            # the real parameters is the identity.
            fast_weights = {
                name: weight if grad is None else weight - lr_inner * grad.detach()
                for (name, weight), grad in zip(fast_weights.items(), grads)
            }

        query_logits = functional_call(self, fast_weights, (query_input, query_attention))
        return F.cross_entropy(query_logits, query_target)

# Example usage: a single fast-adaptation episode on random dummy data.
model_name = "allenai/longformer-base-4096"
model = LongformerClassificationModel(model_name=model_name)
optimizer = AdamW(params=model.parameters(), lr=1e-5)

# Support set: 8 sequences of 512 token ids drawn from [0, 100), every
# position attended, random binary labels.
support_input = torch.randint(0, 100, (8, 512))
support_attention = torch.ones_like(support_input)
support_target = torch.randint(0, 2, (8,))
support_data = (support_input, support_attention, support_target)

# Query set: same shapes, drawn independently after the support set.
query_input = torch.randint(0, 100, (8, 512))
query_attention = torch.ones_like(query_input)
query_target = torch.randint(0, 2, (8,))
query_data = (query_input, query_attention, query_target)

query_loss = model.fast_adapt(support_data, query_data, optimizer)
print("Query Loss: {}".format(query_loss.item()))