<a href="https://colab.research.google.com/github/OneFineStarstuff/State-of-the-Art/blob/main/Transformers_and_Attention_Mechanisms.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
pip install torchtext

In [None]:
pip install datasets

In [None]:
import torch
import torch.nn as nn
from transformers import BertTokenizer, BertModel
from datasets import load_dataset
from torch.utils.data import DataLoader

# Select the compute device: CUDA when torch reports it as available,
# otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Load the 'imdb' dataset and the tokenizer for the 'bert-base-uncased'
# checkpoint; both names are reused by the cells below.
dataset = load_dataset(
    'imdb',
)
tokenizer = BertTokenizer.from_pretrained(
    'bert-base-uncased',
)

def encode_batch(batch):
    """Tokenize the ``'text'`` field of a batch to one fixed sequence length.

    Intended to be used as the function passed to ``dataset.map(...,
    batched=True)``.

    Args:
        batch: Mapping with a ``'text'`` entry holding the raw strings of the
            examples in this map batch.

    Returns:
        The tokenizer output for ``batch['text']`` (token ids, attention mask,
        and whatever else the tokenizer emits), with every sequence truncated
        and padded to the tokenizer's maximum length.
    """
    # padding='max_length' (instead of padding=True, i.e. "longest in this
    # call") guarantees that every map batch yields sequences of the same
    # length. Otherwise two map batches can be padded to different lengths and
    # the DataLoader's default collate cannot stack examples drawn from both.
    # No return_tensors here: the dataset stores plain lists, and conversion
    # to torch tensors is done by the dataset's output format.
    return tokenizer(batch['text'], padding='max_length', truncation=True)

# Run the tokenizer over the dataset in batches, then expose only the
# columns the training loop reads, formatted as torch tensors.
encoded_dataset = dataset.map(function=encode_batch, batched=True)
encoded_dataset.set_format(
    type='torch',
    columns=['input_ids', 'attention_mask', 'label'],
)

# Pull out the two splits used below.
train_dataset, test_dataset = encoded_dataset['train'], encoded_dataset['test']

# Batches of 16; only the training loader shuffles.
train_loader = DataLoader(dataset=train_dataset, batch_size=16, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=16)

class TransformerClassifier(nn.Module):
    """Pretrained 'bert-base-uncased' encoder with a one-unit sigmoid head."""

    def __init__(self):
        super().__init__()
        # Encoder first, then the head sized from the encoder's config.
        self.bert = BertModel.from_pretrained('bert-base-uncased')
        encoder_width = self.bert.config.hidden_size
        self.fc = nn.Linear(in_features=encoder_width, out_features=1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, input_ids, attention_mask):
        """Return sigmoid scores computed from the first-position hidden state.

        Args:
            input_ids: Passed to the encoder as ``input_ids``.
            attention_mask: Passed to the encoder as ``attention_mask``.

        Returns:
            The sigmoid of the linear head's output; the last axis has size 1.
        """
        encoded = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        # Keep only position 0 of the sequence axis (the slot the original
        # code calls the CLS output) as the summary vector.
        first_token = encoded.last_hidden_state[:, 0, :]
        return self.sigmoid(self.fc(first_token))

# Build the classifier and move it onto the selected device.
model = TransformerClassifier()
model = model.to(device)

# Adam over every model parameter with a learning rate of 2e-5.
optimizer = torch.optim.Adam(params=model.parameters(), lr=2e-5)

# Binary cross-entropy applied to the model's sigmoid outputs.
criterion = nn.BCELoss()
criterion = criterion.to(device)

def train(model, data_loader, optimizer, criterion, device):
    """Run one training pass over ``data_loader`` and return the mean batch loss.

    Args:
        model: Module called as ``model(input_ids, attention_mask)``. Its
            output is squeezed on the last axis before being passed to
            ``criterion``.
        data_loader: Iterable of batches, each a mapping with
            ``'input_ids'``, ``'attention_mask'`` and ``'label'`` tensors.
            It does not need to support ``len()``.
        optimizer: Optimizer whose gradients are zeroed before, and stepped
            after, every batch.
        criterion: Loss callable invoked as ``criterion(outputs, labels)``
            with float labels.
        device: Device the batch tensors are moved to before the forward pass.

    Returns:
        The arithmetic mean of the per-batch loss values as a float, or
        ``0.0`` when ``data_loader`` yields no batches.
    """
    model.train()
    total_loss = 0.0
    # Count batches while iterating instead of calling len(data_loader):
    # this works for iterables without __len__ and avoids a division by
    # zero when the loader is empty.
    num_batches = 0
    for num_batches, batch in enumerate(data_loader, start=1):
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        # The loss is computed against float targets, so cast the labels.
        labels = batch['label'].float().to(device)

        optimizer.zero_grad()
        # Drop the trailing size-1 axis so outputs line up with labels.
        outputs = model(input_ids, attention_mask).squeeze(-1)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    if num_batches == 0:
        return 0.0
    return total_loss / num_batches

# Three full passes over the training loader, printing the mean loss
# returned by train() after each pass.
for epoch in range(3):
    train_loss = train(
        model=model,
        data_loader=train_loader,
        optimizer=optimizer,
        criterion=criterion,
        device=device,
    )
    print(f'Epoch {epoch+1}, Train Loss: {train_loss:.3f}')