In [1]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from transformers import BertTokenizer, BertModel
import pandas as pd
import json

# Reproducibility: seed torch's global RNG, which drives DataLoader shuffling
# (shuffle=True below) and the random init of the classifier's linear head.
SEED = 42
torch.manual_seed(SEED)

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


In [2]:
class BertClassifier(nn.Module):
    """Linear classification head on top of a pretrained BERT encoder.

    Token embeddings from the encoder's last layer are mean-pooled over the
    non-padding positions (as marked by ``attention_mask``) and fed to a
    single linear layer that produces class logits.

    Args:
        num_classes: Number of output logits.
        freeze_bert: If True, every encoder parameter gets
            ``requires_grad = False`` so only the linear head is trained.
        pretrained_name: Model id / path passed to
            ``BertModel.from_pretrained``.
    """

    def __init__(self, num_classes=12, freeze_bert=True, pretrained_name='bert-base-uncased'):
        super().__init__()
        self.bert = BertModel.from_pretrained(pretrained_name)
        # Size the head from the encoder config rather than hard-coding 768,
        # so larger BERT variants work without edits.
        self.fc = nn.Linear(self.bert.config.hidden_size, num_classes)

        if freeze_bert:
            for param in self.bert.parameters():
                param.requires_grad = False

    def forward(self, input_ids, attention_mask=None):
        """Return class logits for a batch of token ids.

        Args:
            input_ids: Token id tensor, passed straight to the encoder.
            attention_mask: Same shape as ``input_ids``; 1 for real tokens,
                0 for padding. When None, all positions are averaged.

        Returns:
            Logits tensor of shape (batch, num_classes).
        """
        outputs = self.bert(input_ids, attention_mask=attention_mask)
        hidden = outputs.last_hidden_state  # (batch, seq_len, hidden_size)

        if attention_mask is None:
            pooled = hidden.mean(dim=1)
        else:
            # Exclude padding from the average: inputs are padded to max_length,
            # so an unmasked mean is dominated by pad-token embeddings.
            mask = attention_mask.unsqueeze(-1).to(hidden.dtype)
            summed = (hidden * mask).sum(dim=1)
            counts = mask.sum(dim=1).clamp(min=1e-9)  # guard all-padding rows
            pooled = summed / counts
        return self.fc(pooled)


In [3]:
class TextDataset(Dataset):
    """Map-style dataset that tokenizes each text when it is accessed.

    Args:
        texts: Indexable sequence of raw strings.
        labels: Optional indexable sequence of integer class ids aligned with
            ``texts``. When None (inference), items carry no ``'labels'`` key.
        tokenizer: Hugging Face-style callable tokenizer (required despite the
            ``None`` default, which is kept for signature compatibility).
        max_len: Every item is truncated / padded to exactly this many tokens.

    Raises:
        ValueError: if ``tokenizer`` is None or ``labels`` and ``texts``
            differ in length.
    """

    def __init__(self, texts, labels=None, tokenizer=None, max_len=512):
        if tokenizer is None:
            raise ValueError("TextDataset requires a tokenizer")
        if labels is not None and len(labels) != len(texts):
            raise ValueError(
                f"texts and labels differ in length: {len(texts)} vs {len(labels)}"
            )
        self.tokenizer = tokenizer
        self.texts = texts
        self.labels = labels
        self.max_len = max_len

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        # Fixed-length padding keeps every item the same size so the default
        # DataLoader collate can stack them.
        encoding = self.tokenizer(
            self.texts[idx],
            add_special_tokens=True,
            max_length=self.max_len,
            padding='max_length',
            truncation=True,
        )
        item = {
            'input_ids': torch.tensor(encoding['input_ids'], dtype=torch.long),
            'attention_mask': torch.tensor(encoding['attention_mask'], dtype=torch.long),
        }
        if self.labels is not None:
            item['labels'] = torch.tensor(self.labels[idx], dtype=torch.long)
        return item


In [4]:
def load_data(path):
    """Load a labeled training set from a JSON file.

    The file is expected to hold a JSON object mapping each label name to a
    sequence of example sentences.

    Args:
        path: Path to the JSON file.

    Returns:
        (texts, labels, label_dict): ``texts`` is the flat list of sentences,
        ``labels`` the parallel list of integer ids, and ``label_dict`` maps
        each label name to its id (ids assigned in the file's key order).
    """
    # Decode explicitly as UTF-8: without an encoding, open() uses the locale
    # default (cp1252 on Windows), which garbles or rejects non-ASCII text.
    # 'utf-8-sig' also strips a leading BOM, which json.load would reject.
    with open(path, 'r', encoding='utf-8-sig') as file:
        data = json.load(file)

    label_dict = {name: idx for idx, name in enumerate(data)}
    texts = []
    labels = []
    for name, sentences in data.items():
        label_id = label_dict[name]
        for sentence in sentences:
            texts.append(sentence)
            labels.append(label_id)
    return texts, labels, label_dict

def load_test_data(path):
    """Read the unlabeled test set: one example per line.

    Leading/trailing whitespace (including the newline) is stripped from each
    line. Blank lines are kept as ``''`` so row indices stay aligned with the
    file, which matters if predictions are written back in file order.

    Args:
        path: Path to the text file.

    Returns:
        List of stripped strings, one per line, in file order.
    """
    # Explicit UTF-8 (see open() docs): the locale default differs per OS.
    # 'utf-8-sig' drops a leading BOM that would otherwise prefix the first text.
    with open(path, 'r', encoding='utf-8-sig') as file:
        return [line.strip() for line in file]

# Data locations and batching settings live in one place; point DATA_DIR at
# your own copy of the competition data to run on another machine.
DATA_DIR = 'C:/Users/Maamar/Desktop/CS_ANLP_KaggComp/Data'
# Every example is padded/truncated to MAX_LEN tokens (padding='max_length'),
# so this directly sets the per-example encoder cost; lower it if too slow.
MAX_LEN = 512
BATCH_SIZE = 16

train_texts, train_labels, label_dict = load_data(f'{DATA_DIR}/augmented_train.json')
test_texts = load_test_data(f'{DATA_DIR}/test_shuffle.txt')
assert train_texts and test_texts, "empty train or test set - check DATA_DIR"

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
train_dataset = TextDataset(train_texts, train_labels, tokenizer, max_len=MAX_LEN)
test_dataset = TextDataset(test_texts, None, tokenizer, max_len=MAX_LEN)

# Shuffle only for training; keep test order so predictions line up with the file.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)


In [5]:
LEARNING_RATE = 2e-5

model = BertClassifier(num_classes=len(label_dict)).to(device)
# BertClassifier freezes the encoder by default, so only the linear head has
# requires_grad=True; give the optimizer just those parameters.
# NOTE(review): 2e-5 is a typical rate for fine-tuning the whole encoder; with a
# frozen encoder and only a linear head, a much larger rate is usually needed - tune.
trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.Adam(trainable_params, lr=LEARNING_RATE)
loss_fn = nn.CrossEntropyLoss()

def train_epoch(model, data_loader, loss_fn, optimizer, device):
    """Run one training epoch and return (accuracy, mean loss).

    Args:
        model: Module called as ``model(input_ids, attention_mask)`` that
            returns logits of shape (batch, num_classes).
        data_loader: Iterable of batch dicts with ``'input_ids'``,
            ``'attention_mask'`` and ``'labels'`` tensors.
        loss_fn: Loss taking (logits, labels). Assumed to use mean reduction
            (the ``nn.CrossEntropyLoss`` default), because each batch's loss is
            re-weighted by its batch size to get a per-sample average.
        optimizer: Optimizer stepped once per batch.
        device: Device the batch tensors are moved to.

    Returns:
        (accuracy, loss): ``accuracy`` is a 0-dim float64 tensor (correct /
        samples seen); ``loss`` is a Python float averaged per sample.

    Raises:
        ValueError: if ``data_loader`` yields no batches.
    """
    model.train()
    correct = 0
    loss_sum = 0.0
    n_seen = 0

    for batch in data_loader:
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        labels = batch["labels"].to(device)

        logits = model(input_ids, attention_mask)
        preds = logits.argmax(dim=1)
        loss = loss_fn(logits, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch_size = labels.size(0)
        correct = correct + (preds == labels).sum()
        # Weight by batch size so a short final batch does not skew the mean,
        # and count samples actually seen rather than len(dataset).
        loss_sum += loss.item() * batch_size
        n_seen += batch_size

    if n_seen == 0:
        raise ValueError("train_epoch: data_loader yielded no batches")
    return correct.double() / n_seen, loss_sum / n_seen

NUM_EPOCHS = 3

for epoch in range(NUM_EPOCHS):
    print(f"Epoch {epoch + 1}")
    train_acc, train_loss = train_epoch(model, train_loader, loss_fn, optimizer, device)
    # train_acc is a 0-dim tensor (possibly on the GPU); convert it so the log
    # shows a number instead of a tensor repr with device/dtype noise.
    print(f"Train loss {train_loss:.4f}, accuracy {float(train_acc):.4f}")


Epoch 1


KeyboardInterrupt: 