In [1]:
# Imports
import os
import pandas as pd
from transformers import BertTokenizer, BertModel
import torch.optim as optim
import torch
from torch.utils.data import Dataset, DataLoader
from torch import nn
import matplotlib.pyplot as plt
# import cupy

In [2]:
# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
torch.cuda.empty_cache()
# cupy.cuda.runtime.deviceReset()
torch.cuda.reset_max_memory_cached()
device



device(type='cuda')

In [3]:
# Train set
train_pos_folder = './data/aclImdb/train/pos'
train_neg_folder = './data/aclImdb/train/neg'
# train_pos_folder = './data/aclImdb/test/pos'
# train_neg_folder = './data/aclImdb/test/neg'

train_pos_sentences = [open(os.path.join(train_pos_folder, f)).read().strip() for f in os.listdir(train_pos_folder)]
train_neg_sentences = [open(os.path.join(train_neg_folder, f)).read().strip() for f in os.listdir(train_neg_folder)]

train_df = pd.DataFrame({
    'text': train_pos_sentences + train_neg_sentences,
    'label': [1] * len(train_pos_sentences) + [0] * len(train_neg_sentences)  # 1 for positive, 0 for negative
})

train_df.head()

Unnamed: 0,text,label
0,Bromwell High is a cartoon comedy. It ran at t...,1
1,Homelessness (or Houselessness as George Carli...,1
2,Brilliant over-acting by Lesley Ann Warren. Be...,1
3,This is easily the most underrated film inn th...,1
4,This is not the typical Mel Brooks film. It wa...,1


In [4]:
# Check for NaN values in columns
print(train_df['text'].isna().sum())
print(train_df['label'].isna().sum())

# Check labels
unique_values = train_df['label'].unique()
print(unique_values)

# Check max length
max_words = train_df['text'].apply(lambda x: len(x.split())).max()
max_words

0
0
[1 0]


2470

In [5]:
# Dataset
class IMDBDataset(Dataset):
    def __init__(self, data, tokenizer, max_len):
        self.data = data
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        text = self.data.iloc[idx, 0]
        label = self.data.iloc[idx, 1]

        encoding = self.tokenizer.encode_plus(
            text,
            add_special_tokens=True,
            max_length=self.max_len,
            padding='max_length',
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt')    # as pytorch tensors

        return {
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'labels': torch.tensor([label], dtype=torch.float)
        }

In [6]:
# Make torch DataLoader
batch_size = 32

# Initialize the BERT tokenizer
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# Create data loaders
train_dataset = IMDBDataset(train_df, tokenizer, max_len=512)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)



In [7]:
# Check tensors
stopper = 0
for batch in train_loader:
    if stopper == 30:
        break
    input_ids = batch['input_ids']
    attention_mask = batch['attention_mask']
    labels = batch['labels']
    print(input_ids.shape, attention_mask.shape, labels.shape)
    stopper += 1

torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32, 1])
torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32, 1])
torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32, 1])
torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32, 1])
torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32, 1])
torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32, 1])
torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32, 1])
torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32, 1])
torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32, 1])
torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32, 1])
torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32, 1])
torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32, 1])
torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32, 1])
torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32, 1])
torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32, 1])
torch.Size([32, 512]) torch.Size([32, 51

In [8]:
# Define hyperparmaters
num_epochs = 3
learning_rate = 1e-5

In [9]:
# Define the BertForSentenceClassification model
class BertForSentenceClassification(nn.Module):
    def __init__(self, bert_model):
        super(BertForSentenceClassification, self).__init__()
        self.bert_model = bert_model
        self.dropout = nn.Dropout(0.1)
        self.classifier = nn.Linear(self.bert_model.config.hidden_size, 1)

    def forward(self, input_ids, attention_mask):
        outputs = self.bert_model(input_ids, attention_mask=attention_mask)
        pooled_output = outputs.pooler_output
        pooled_output = self.dropout(pooled_output)
        outputs = self.classifier(pooled_output)
        return outputs

In [10]:
# Apply and check the model

# Load pre-trained BERT model
bert_model = BertModel.from_pretrained('bert-base-uncased')

# Create an instance of the custom model
model = BertForSentenceClassification(bert_model)
model.to(device=device)

BertForSentenceClassification(
  (bert_model): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSdpaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), 

In [11]:
# Basic check
input_ids = torch.randint(0, 100, (32, 512)).to(device)  # random input IDs
attention_mask = torch.ones((32, 512)).to(device)  # random attention mask
output = model(input_ids, attention_mask)
print(output.shape)

  attn_output = torch.nn.functional.scaled_dot_product_attention(


torch.Size([32, 1])


In [12]:
# Define the optimizer and loss function
criterion = nn.BCELoss()

# # Define the optimizer
optimizer = optim.Adam(model.parameters(), lr=1e-5)

In [13]:
i = 0
losses = []
for epoch in range(1):
    model.train()
    total_loss = 0
    for batch in train_loader:
        print(f"batch: {i}", end="\r")
        
        input_ids = batch['input_ids'].to(device=device)
        attention_mask = batch['attention_mask'].to(device=device)
        labels = batch['labels'].to(device=device)

        optimizer.zero_grad()

        outputs = model(input_ids, attention_mask)
        loss = criterion(outputs, labels)

        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        
        i += 1
        if i == 2:
            break

    # Append the current loss to the list
    losses.append(loss.item())

    print(f'Epoch {epoch+1}, Loss: {total_loss / len(train_loader)}')

Epoch 1, Loss: 0.0020624914437608645


In [None]:
model.eval()
test_loss = 0
test_metrics = []
with torch.no_grad():
    for batch in test_loader:
        input_ids = batch['input_ids'].to(device=device)
        attention_mask = batch['attention_mask'].to(device=device)
        labels = batch['labels'].to(device=device)

        outputs = model(input_ids, attention_mask)
        loss = criterion(outputs, labels)
        test_loss += loss.item()

        # Calculate metrics (e.g., accuracy, F1-score)
        _, preds = torch.max(outputs, dim=1)
        accuracy = (preds == labels).sum().item() / len(labels)
        test_metrics.append(accuracy)

test_loss /= len(test_loader)
test_metrics = np.mean(test_metrics)

print(f'Test Loss: {test_loss:.4f}')
print(f'Test Accuracy: {test_metrics:.4f}')