# In [6]:  (Jupyter notebook cell prompt)
import torch
from torch.utils.data import Dataset, DataLoader
from datasets import load_dataset
from transformers import BertTokenizer, BertModel
import random
import numpy as np
import time
import torch.nn as nn
import torch.optim as optim

# Reproducibility: feed one seed to every RNG this script touches
SEED = 1234
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.backends.cudnn.deterministic = True

# Pretrained tokenizer / encoder pair
MAX_LEN = 512  # BERT's maximum sequence length
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
bert = BertModel.from_pretrained('bert-base-uncased')

# IMDB: keep the official test split, carve 20% of train off for validation
imdb = load_dataset('imdb')
test_data = imdb['test']
train_valid = imdb['train'].train_test_split(test_size=0.2, seed=SEED)
train_data, valid_data = train_valid['train'], train_valid['test']

class BERTDataset(Dataset):
    """Map-style dataset that tokenizes one text/label record per lookup.

    Args:
        data: Indexable, sized collection whose items expose ``'text'`` and
            ``'label'`` keys.
        tokenizer: Object providing ``encode_plus``; it is called with
            ``return_tensors='pt'``.
        max_len: Length passed to the tokenizer as ``max_length``; examples
            are truncated and padded (``padding='max_length'``) to it.
    """

    def __init__(self, data, tokenizer, max_len):
        self.data = data
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        """Return the number of records in the wrapped collection."""
        return len(self.data)

    def __getitem__(self, idx):
        """Tokenize record ``idx``.

        Returns:
            dict with flattened ``'input_ids'`` and ``'attention_mask'``
            tensors from the tokenizer and a float ``'label'`` tensor.
        """
        # Fetch the record a single time; indexing the backing collection
        # once per field would materialize the same row twice.
        record = self.data[idx]
        encoding = self.tokenizer.encode_plus(
            record['text'],
            add_special_tokens=True,
            max_length=self.max_len,
            truncation=True,
            padding='max_length',
            return_tensors='pt'
        )
        return {
            # return_tensors='pt' adds a leading batch axis; drop it so the
            # collate step can stack examples itself.
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'label': torch.tensor(record['label'], dtype=torch.float)
        }

def collate_fn(batch):
    """Stack per-example tensors into batch tensors, field by field.

    Args:
        batch: Sequence of dicts, each holding ``'input_ids'``,
            ``'attention_mask'`` and ``'label'`` tensors.

    Returns:
        dict with the same three keys, each mapped to the ``torch.stack``
        of that field across the batch.
    """
    fields = ('input_ids', 'attention_mask', 'label')
    return {
        field: torch.stack([example[field] for example in batch])
        for field in fields
    }

# Wrap each split in a tokenizing dataset, then batch it
BATCH_SIZE = 16

train_dataset, valid_dataset, test_dataset = (
    BERTDataset(split, tokenizer, MAX_LEN)
    for split in (train_data, valid_data, test_data)
)

# Only the training loader shuffles; validation and test keep dataset order
train_loader = DataLoader(
    train_dataset,
    shuffle=True,
    batch_size=BATCH_SIZE,
    collate_fn=collate_fn,
)
valid_loader = DataLoader(
    valid_dataset,
    batch_size=BATCH_SIZE,
    collate_fn=collate_fn,
)
test_loader = DataLoader(
    test_dataset,
    batch_size=BATCH_SIZE,
    collate_fn=collate_fn,
)

# Model definition
class BERTGRUSentiment(nn.Module):
    """GRU classifier head on top of a BERT-style encoder.

    The encoder is run under ``torch.no_grad()`` to produce per-token
    embeddings; a (optionally bidirectional, multi-layer) GRU summarizes
    them and a linear layer maps the final hidden state to the output.

    Args:
        bert: Encoder module exposing ``config.hidden_size`` and returning a
            sequence whose first element is the per-token embedding tensor
            when called with ``input_ids`` and ``attention_mask``.
        hidden_dim: GRU hidden size per direction.
        output_dim: Number of output logits per example.
        n_layers: Number of stacked GRU layers.
        bidirectional: Whether the GRU runs in both directions.
        dropout: Dropout probability applied to the final hidden state and,
            when ``n_layers >= 2``, between GRU layers.
    """

    def __init__(self, bert, hidden_dim, output_dim, n_layers, bidirectional, dropout):
        super().__init__()
        self.bert = bert
        self.rnn = nn.GRU(bert.config.hidden_size,
                         hidden_dim,
                         num_layers=n_layers,
                         bidirectional=bidirectional,
                         batch_first=True,
                         # inter-layer dropout is only meaningful with 2+ layers
                         dropout=0 if n_layers < 2 else dropout)
        self.out = nn.Linear(hidden_dim * 2 if bidirectional else hidden_dim, output_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, input_ids, attention_mask):
        """Return logits of shape ``(batch, output_dim)``.

        Args:
            input_ids: ``(batch, seq)`` token ids.
            attention_mask: ``(batch, seq)`` mask, non-zero on real tokens.
                Assumes padding sits on the right (real tokens first), since
                only the per-row count of real tokens is used for packing.
        """
        # The encoder is used as a fixed feature extractor: no gradients.
        with torch.no_grad():
            embedded = self.bert(input_ids=input_ids, attention_mask=attention_mask)[0]

        # Pack by true length so the GRU stops at the last real token instead
        # of stepping through padding; otherwise the final hidden state
        # depends on how much padding each example carries. Lengths must be
        # a CPU integer tensor, and are clamped to >= 1 because packing
        # rejects empty sequences.
        lengths = attention_mask.sum(dim=1).clamp(min=1).long().cpu()
        packed = nn.utils.rnn.pack_padded_sequence(
            embedded, lengths, batch_first=True, enforce_sorted=False
        )
        _, hidden = self.rnn(packed)

        if self.rnn.bidirectional:
            # last layer's forward and backward final states
            hidden = self.dropout(torch.cat((hidden[-2], hidden[-1]), dim=1))
        else:
            hidden = self.dropout(hidden[-1])
        return self.out(hidden)

# Build the classifier around the pretrained encoder
model = BERTGRUSentiment(
    bert,
    hidden_dim=256,
    output_dim=1,
    n_layers=2,
    bidirectional=True,
    dropout=0.25,
)

# Freeze every parameter whose name contains 'bert' so it is never updated
for name, param in model.named_parameters():
    if 'bert' not in name:
        continue
    param.requires_grad = False

# Prefer the GPU when one is available
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
model = model.to(device)

# Binary classification on a single logit; Adam with default settings
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(model.parameters())

# Training functions
def train_epoch(model, data_loader, optimizer, device, loss_fn=None):
    """Run one optimization pass over ``data_loader``.

    Args:
        model: Module called as ``model(input_ids, attention_mask)`` and
            returning ``(batch, 1)`` logits.
        data_loader: Sized iterable of dict batches with ``'input_ids'``,
            ``'attention_mask'`` and ``'label'`` tensors.
        optimizer: Optimizer stepped once per batch.
        device: Device the batch tensors are moved to.
        loss_fn: Loss called as ``loss_fn(logits, labels)``. Defaults to the
            module-level ``criterion`` when omitted.

    Returns:
        ``(mean batch loss, accuracy)`` where accuracy is the fraction of
        examples seen whose rounded sigmoid output equals the label.
    """
    if loss_fn is None:
        loss_fn = criterion

    model.train()
    total_loss = 0.0
    correct_preds = 0
    n_seen = 0

    for batch in data_loader:
        optimizer.zero_grad()
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['label'].to(device)

        # Drop only the trailing logit axis. A bare squeeze() would also
        # remove the batch axis of a one-example batch, leaving a scalar
        # that no longer matches the (1,) label tensor.
        outputs = model(input_ids, attention_mask).squeeze(-1)
        loss = loss_fn(outputs, labels)

        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        preds = torch.round(torch.sigmoid(outputs.detach()))
        correct_preds += torch.sum(preds == labels).item()
        # Count what was actually processed so the accuracy stays right even
        # if the loader drops or subsamples examples.
        n_seen += labels.size(0)

    return total_loss / len(data_loader), correct_preds / n_seen

def eval_model(model, data_loader, device, loss_fn=None):
    """Evaluate ``model`` over ``data_loader`` without updating it.

    Args:
        model: Module called as ``model(input_ids, attention_mask)`` and
            returning ``(batch, 1)`` logits.
        data_loader: Sized iterable of dict batches with ``'input_ids'``,
            ``'attention_mask'`` and ``'label'`` tensors.
        device: Device the batch tensors are moved to.
        loss_fn: Loss called as ``loss_fn(logits, labels)``. Defaults to the
            module-level ``criterion`` when omitted.

    Returns:
        ``(mean batch loss, accuracy)`` where accuracy is the fraction of
        examples seen whose rounded sigmoid output equals the label.
    """
    if loss_fn is None:
        loss_fn = criterion

    model.eval()
    total_loss = 0.0
    correct_preds = 0
    n_seen = 0

    with torch.no_grad():
        for batch in data_loader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['label'].to(device)

            # Drop only the trailing logit axis. A bare squeeze() would also
            # remove the batch axis of a one-example batch, leaving a scalar
            # that no longer matches the (1,) label tensor.
            outputs = model(input_ids, attention_mask).squeeze(-1)
            loss = loss_fn(outputs, labels)

            total_loss += loss.item()
            preds = torch.round(torch.sigmoid(outputs))
            correct_preds += torch.sum(preds == labels).item()
            # Count what was actually processed so the accuracy stays right
            # even if the loader drops or subsamples examples.
            n_seen += labels.size(0)

    return total_loss / len(data_loader), correct_preds / n_seen

# Training loop: train, validate, and checkpoint whenever validation loss improves
N_EPOCHS = 3
best_valid_loss = float('inf')

for epoch in range(N_EPOCHS):
    start_time = time.time()

    train_loss, train_acc = train_epoch(model, train_loader, optimizer, device)
    valid_loss, valid_acc = eval_model(model, valid_loader, device)

    end_time = time.time()
    # Truncate to whole seconds first: divmod on a float returns floats,
    # which would print as e.g. "3.0m 12.345678s".
    epoch_mins, epoch_secs = divmod(int(end_time - start_time), 60)

    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        torch.save(model.state_dict(), 'bert-gru-model.pt')

    print(f'Epoch {epoch+1}')
    print(f'Train Loss: {train_loss:.3f} | Acc: {train_acc*100:.2f}%')
    print(f'Valid Loss: {valid_loss:.3f} | Acc: {valid_acc*100:.2f}%')
    print(f'Time: {epoch_mins}m {epoch_secs}s\n')

# Final evaluation with the best checkpoint. map_location keeps the load
# working when the checkpoint was written on a different device than the
# one in use now (e.g. saved on GPU, reloaded on CPU).
model.load_state_dict(torch.load('bert-gru-model.pt', map_location=device))
test_loss, test_acc = eval_model(model, test_loader, device)
print(f'Test Loss: {test_loss:.3f} | Acc: {test_acc*100:.2f}%')

# Prediction function
def predict_sentiment(text):
    """Classify a single piece of text with the module-level ``model``.

    The text is tokenized with the module-level ``tokenizer`` (truncated and
    padded to ``MAX_LEN``), moved to ``device`` and scored without gradients.

    Args:
        text: Raw input string.

    Returns:
        ``(label, probability)``: ``label`` is ``'Positive'`` when the
        sigmoid of the model output exceeds 0.5, otherwise ``'Negative'``;
        ``probability`` is that sigmoid value as a Python float.
    """
    model.eval()

    tokens = tokenizer.encode_plus(
        text,
        add_special_tokens=True,
        truncation=True,
        max_length=MAX_LEN,
        padding='max_length',
        return_tensors='pt',
    )
    ids = tokens['input_ids'].to(device)
    mask = tokens['attention_mask'].to(device)

    with torch.no_grad():
        logit = model(ids, mask)
        proba = torch.sigmoid(logit).item()

    if proba > 0.5:
        verdict = 'Positive'
    else:
        verdict = 'Negative'
    return verdict, proba

# Example usage: the first review should come out Positive, the second Negative
for review in (
    "This movie was an amazing experience!",
    "Terrible plot and bad acting.",
):
    print(predict_sentiment(review))

# Notebook output: KeyboardInterrupt (cell execution was interrupted)