# Part 3: RoBERTa-LSTM Model Building

In this notebook, we use the tokenized Yelp dataset to fine-tune a RoBERTa model, topped with a bidirectional LSTM and a linear classification head, for our ternary sentiment analysis task. RoBERTa is a more robustly trained variant of the original BERT, pretrained on a much larger corpus (more than 160 GB of text), which makes it well suited to our task.

## Installing Dependencies

In [None]:
! pip install -q gdown google-api-python-client google-auth google-auth-httplib2 google-auth-oauthlib

## Downloading tokenized datasets and their labels

In [None]:
# Google Drive file IDs of the tokenized datasets and their labels; each one is
# fetched with gdown in the next cell.
file_ids = [
    '11JFdenZ7TMGib8Tt8mPfQITGtn28KcBM',
    '1zLzSOV6498qSusNIXvuOQn4PRoe2be_a',
    '1j4JnzhLushqU-AFd1KOn08_Kt_YC77YN',
    '1QtvHQO4TbSRqSRnuqL27pXTy5lMVdHsP',
    '1_T2tYQpBDr1pC6HNqVTNeJUtD3gS4KeN',
    '17A_cbIZP53Ho9itsF8Sz43IScLssin_U',
]

In [None]:
# Download every entry of `file_ids` by shelling out to the gdown CLI; IPython
# substitutes {file_id} with the value of the Python loop variable before the
# command runs.
# NOTE(review): presumably these are the *_encodings.pt / *_labels.pt files loaded
# further down - confirm against the Drive folder.
for file_id in file_ids:
    ! gdown {file_id}

## Importing Dependencies

In [None]:
from google.oauth2 import service_account
from googleapiclient.discovery import build
from googleapiclient.http import MediaFileUpload, MediaIoBaseDownload
import io
import re
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from transformers import AutoModel, get_linear_schedule_with_warmup
from torch.amp import autocast, GradScaler
import time
from tqdm import tqdm
from torch.nn import DataParallel
from sklearn.metrics import precision_score, recall_score, f1_score
import os
import gdown
import sys

## Setting up credentials to save checkpoints

In [None]:
# Drive identifiers used for checkpointing. Each value can be overridden through an
# environment variable of the same name, so the notebook can be pointed at a different
# service account or folder without editing code; the defaults are the original values.
# NOTE(review): the default credentials file is fetched by ID through a plain Drive URL -
# confirm that this service-account key is meant to be shared that way.
CREDENTIALS_FILE_ID = os.environ.get('CREDENTIALS_FILE_ID', '10O9NYV9U8l6F3CSxZ4_5kmN9BQDc12FN')  # Drive ID of the service-account JSON
CREDENTIALS_LOCAL = os.environ.get('CREDENTIALS_LOCAL', 'credentials.json')  # local path of that JSON
CHECKPOINT_FOLDER_ID = os.environ.get('CHECKPOINT_FOLDER_ID', '17zp1KQwNd3MKmnkd_-BWiTtscmnH6Sdm')  # Drive folder holding checkpoints

## Setting up GPU

In [None]:
# Prefer the GPU. Fall back to the CPU (with a warning) when CUDA is unavailable so the
# problem is reported here rather than surfacing later when tensors are moved to `device`.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if device.type != "cuda":
    print("[WARN] CUDA is not available - falling back to CPU; training will be very slow.")

## Loading Dataset

In [None]:
# Load the pre-tokenized encodings of the three splits (train, val, test - in that order)
# from the working directory.
train_encodings, val_encodings, test_encodings = (
    torch.load(f"{split}_encodings.pt") for split in ("train", "val", "test")
)

In [None]:
# Load the label tensors that accompany each split's encodings (train, val, test).
train_labels, val_labels, test_labels = (
    torch.load(f"{split}_labels.pt") for split in ("train", "val", "test")
)

## Dataset Preparation

In [None]:
# Bundle (input_ids, attention_mask, label) per example for each split; batches are
# later unpacked positionally in that order.
train_dataset, val_dataset, test_dataset = (
    TensorDataset(encodings['input_ids'], encodings['attention_mask'], labels)
    for encodings, labels in (
        (train_encodings, train_labels),
        (val_encodings, val_labels),
        (test_encodings, test_labels),
    )
)

In [None]:
batch_size = 64

# Only the training split is shuffled; the evaluation splits keep their stored order.
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=batch_size,
    shuffle=True,
    pin_memory=True,
)
val_loader, test_loader = (
    DataLoader(dataset=eval_dataset, batch_size=batch_size, pin_memory=True)
    for eval_dataset in (val_dataset, test_dataset)
)

## Defining Architecture

In [None]:
class RoBERTa_LSTM(nn.Module):
    """RoBERTa encoder followed by a bidirectional LSTM and a linear classification head.

    Args:
        roberta_model: Name or path handed to ``AutoModel.from_pretrained``.
        lstm_hidden: Hidden size of each LSTM direction; the pooled feature vector is
            ``2 * lstm_hidden`` wide because the LSTM is bidirectional.
        num_classes: Number of output logits.
        dropout: Dropout probability applied right before the final linear layer.
    """

    def __init__(self, roberta_model='roberta-base', lstm_hidden=256, num_classes=3, dropout=0.4):
        super().__init__()
        self.roberta = AutoModel.from_pretrained(roberta_model)
        self.lstm = nn.LSTM(input_size=self.roberta.config.hidden_size,
                            hidden_size=lstm_hidden,
                            batch_first=True,
                            bidirectional=True)

        self.norm = nn.LayerNorm(lstm_hidden * 2)
        self.drop = nn.Dropout(dropout)
        self.fc = nn.Linear(lstm_hidden * 2, num_classes)

    def forward(self, input_ids, attention_mask):
        """Return class logits for a batch.

        Args:
            input_ids: Token ids, forwarded to the encoder unchanged.
            attention_mask: Mask with 1 for real tokens and 0 for padding; used both by
                the encoder and to exclude padded positions from the pooling step.

        Returns:
            Tensor of logits with one row per example and ``num_classes`` columns.
        """
        token_states = self.roberta(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state
        lstm_out, _ = self.lstm(token_states)

        # Masked mean pooling: average over real tokens only. A plain mean over the
        # sequence axis would also average in the states at padded positions, making
        # the pooled vector depend on how much padding an example happens to carry.
        # NOTE(review): the LSTM itself still runs over padded positions; only the
        # pooling is masked. Parameter names/shapes are unchanged, so existing
        # checkpoints still load, but they were trained with unmasked pooling.
        mask = attention_mask.unsqueeze(-1).to(lstm_out.dtype)
        summed = (lstm_out * mask).sum(dim=1)
        token_counts = mask.sum(dim=1).clamp(min=1.0)  # guard against an all-zero mask
        pooled = summed / token_counts

        normed = self.norm(pooled)
        return self.fc(self.drop(normed))

## Model Initialization

In [None]:
# Build the network, wrap it for multi-GPU data parallelism, then place it on the device.
model = DataParallel(RoBERTa_LSTM()).to(device)

criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
optimizer = optim.AdamW(params=model.parameters(), lr=2e-5, weight_decay=0.01)

# The schedule spans every optimizer step of the run, with the first 10% used as warm-up.
# NOTE(review): the literal 5 is the epoch count and has to stay in sync with `epochs`,
# which is assigned in the next cell.
num_training_steps = 5 * len(train_loader)
num_warmup_steps = int(0.1 * num_training_steps)
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=num_warmup_steps,
    num_training_steps=num_training_steps,
)

In [None]:
scaler = GradScaler()
epochs = 5

## Training

In [None]:
def download_credentials():
    """Make sure the service-account credentials file is available locally.

    When ``CREDENTIALS_LOCAL`` already exists only a log line is printed; otherwise the
    Drive file identified by ``CREDENTIALS_FILE_ID`` is fetched to that path with gdown.
    """
    if not os.path.exists(CREDENTIALS_LOCAL):
        source_url = f"https://drive.google.com/uc?id={CREDENTIALS_FILE_ID}"
        print("[INFO] Downloading credentials.json from Drive...")
        gdown.download(source_url, CREDENTIALS_LOCAL, quiet=False)
    else:
        print("[INFO] Credentials.json already exists locally.")

def get_drive_service():
    """Return a Drive v3 API client authenticated with the local service-account file."""
    drive_scopes = ['https://www.googleapis.com/auth/drive']
    credentials = service_account.Credentials.from_service_account_file(
        CREDENTIALS_LOCAL, scopes=drive_scopes
    )
    return build('drive', 'v3', credentials=credentials)

def list_files(service, query):
    """Return every Drive file matching ``query``.

    The Drive API returns results one page at a time; this follows ``nextPageToken``
    until it is exhausted so the caller sees all matches instead of just the first page.

    Args:
        service: Drive API client exposing ``files().list(...).execute()``.
        query: Drive search query string passed as ``q``.

    Returns:
        List of dicts with the ``id`` and ``name`` of each matching file (possibly empty).
    """
    matches = []
    page_token = None
    while True:
        response = service.files().list(
            q=query,
            fields="nextPageToken, files(id, name)",
            pageToken=page_token,  # None on the first request
        ).execute()
        matches.extend(response.get('files', []))
        page_token = response.get('nextPageToken')
        if not page_token:
            return matches

def download_file(service, file_id, filepath):
    """Download the Drive file ``file_id`` to the local path ``filepath``.

    Args:
        service: Drive API client exposing ``files().get_media``.
        file_id: ID of the Drive file to fetch.
        filepath: Local destination; opened for binary writing (existing content is replaced).
    """
    request = service.files().get_media(fileId=file_id)
    # The context manager closes the file handle even if a chunk download raises, so
    # the descriptor is not leaked and the file is closed before callers read it back.
    with io.FileIO(filepath, 'wb') as fh:
        downloader = MediaIoBaseDownload(fh, request)
        done = False
        while not done:
            _, done = downloader.next_chunk()
    print(f"[INFO] Downloaded {filepath} from Drive.")

def upload_or_replace_file(service, folder_id, filename):
    """Upload the local file ``filename`` into the Drive folder ``folder_id``.

    If the folder already holds a non-trashed file with the same name, the content of
    the first match is replaced; otherwise a new file is created in the folder.
    """
    existing = list_files(
        service,
        f"name='{filename}' and '{folder_id}' in parents and trashed=false",
    )
    media = MediaFileUpload(filename, resumable=True)

    if not existing:
        service.files().create(
            body={'name': filename, 'parents': [folder_id]},
            media_body=media,
        ).execute()
        print(f"[INFO] Uploaded new file to Drive: {filename}")
        return

    service.files().update(fileId=existing[0]['id'], media_body=media).execute()
    print(f"[INFO] Updated file on Drive: {filename}")

def find_latest_epoch(service, folder_id):
    """Locate the checkpoint with the highest epoch number in a Drive folder.

    Args:
        service: Drive API client, passed through to ``list_files``.
        folder_id: ID of the Drive folder to search.

    Returns:
        Tuple ``(max_epoch, latest_file)``: the highest epoch number found and the
        corresponding file dict, or ``(0, None)`` when no checkpoint exists.
    """
    query = f"name contains 'checkpoint_epoch_' and '{folder_id}' in parents and trashed=false"
    max_epoch = 0
    latest_file = None
    for file in list_files(service, query):
        # fullmatch (not match) so that names with extra trailing text, such as
        # "checkpoint_epoch_3.pt.bak", are not mistaken for real checkpoints.
        match = re.fullmatch(r'checkpoint_epoch_(\d+)\.pt', file['name'])
        if match is None:
            continue
        epoch_num = int(match.group(1))
        if epoch_num > max_epoch:
            max_epoch = epoch_num
            latest_file = file
    return max_epoch, latest_file

def _validate_epoch(model, criterion, loader, desc):
    """Evaluate ``model`` on ``loader`` without gradients and return the epoch's metrics.

    Returns:
        Dict with keys ``loss`` (mean per-batch loss), ``accuracy`` (percent) and
        ``precision`` / ``recall`` / ``f1`` (weighted averages). Every value is 0 when
        the loader yields no samples.
    """
    model.eval()
    loss_sum = 0.0
    correct = 0
    total = 0
    all_preds = []
    all_labels = []

    val_bar = tqdm(loader, desc=desc, leave=False, file=sys.stdout)

    with torch.no_grad():
        for batch in val_bar:
            input_ids = batch[0].to(device)
            attention_mask = batch[1].to(device)
            labels = batch[2].to(device)

            outputs = model(input_ids, attention_mask)
            loss = criterion(outputs, labels)
            loss_value = loss.item()
            loss_sum += loss_value

            preds = torch.argmax(outputs, dim=1)
            all_preds.extend(preds.cpu().tolist())
            all_labels.extend(labels.cpu().tolist())

            correct += (preds == labels).sum().item()
            total += labels.size(0)

            # Only cheap running statistics go on the progress bar; the sklearn metrics
            # are computed once after the loop instead of on every batch.
            val_bar.set_postfix({
                'val_loss': f'{loss_value:.4f}',
                'acc': f'{(correct / total * 100) if total > 0 else 0:.2f}%'
            })

    if total == 0:
        return {'loss': 0.0, 'accuracy': 0, 'precision': 0, 'recall': 0, 'f1': 0}

    return {
        'loss': loss_sum / len(loader),
        'accuracy': correct / total * 100,
        'precision': precision_score(all_labels, all_preds, average='weighted', zero_division=0),
        'recall': recall_score(all_labels, all_preds, average='weighted', zero_division=0),
        'f1': f1_score(all_labels, all_preds, average='weighted', zero_division=0),
    }


def train(model, optimizer, scheduler, criterion, train_loader, val_loader, scaler, epochs_total):
    """Train with mixed precision, validating and checkpointing to Drive after every epoch.

    On start, the newest ``checkpoint_epoch_<n>.pt`` in ``CHECKPOINT_FOLDER_ID`` (if any)
    is downloaded and model, optimizer, scheduler, scaler and best accuracy are restored,
    so training resumes at epoch ``n + 1``. After each epoch a full checkpoint is saved
    and uploaded; whenever validation accuracy improves, the model weights are also
    saved and uploaded as ``best_model.pt``.

    Args:
        model: Network to train; may be wrapped in ``DataParallel``.
        optimizer: Optimizer stepped once per batch through ``scaler``.
        scheduler: Learning-rate scheduler stepped once per batch.
        criterion: Loss function applied to ``(logits, labels)``.
        train_loader: Loader yielding ``(input_ids, attention_mask, labels)`` batches.
        val_loader: Loader with the same batch layout, used for validation.
        scaler: Gradient scaler for mixed-precision training.
        epochs_total: Last epoch number (1-based) to train.
    """
    download_credentials()
    service = get_drive_service()

    # State dicts are saved/loaded on the underlying module so the checkpoint keys do
    # not depend on whether the model is wrapped in DataParallel.
    core_model = model.module if isinstance(model, DataParallel) else model

    start_epoch = 1
    best_accuracy = 0.0

    latest_epoch, latest_ckpt_file = find_latest_epoch(service, CHECKPOINT_FOLDER_ID)

    if latest_epoch > 0:
        print(f"[INFO] Found latest checkpoint on Drive: Epoch {latest_epoch} - downloading...")
        resume_path = f"checkpoint_epoch_{latest_epoch}.pt"
        download_file(service, latest_ckpt_file['id'], resume_path)
        checkpoint = torch.load(resume_path, map_location=device)
        core_model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        scaler.load_state_dict(checkpoint['scaler_state_dict'])
        best_accuracy = checkpoint.get('best_accuracy', 0.0)
        start_epoch = latest_epoch + 1
        print(f"[INFO] Resuming training from epoch {start_epoch}")
    else:
        print("[INFO] No checkpoint found on Drive. Starting fresh training.")

    training_start = time.time()

    for epoch in range(start_epoch, epochs_total + 1):
        model.train()
        total_loss = 0
        correct = 0
        total = 0

        train_bar = tqdm(train_loader, desc=f"Epoch {epoch}/{epochs_total} [Train]", leave=True, file=sys.stdout)

        for batch in train_bar:
            input_ids = batch[0].to(device, non_blocking=True)
            attention_mask = batch[1].to(device, non_blocking=True)
            labels = batch[2].to(device, non_blocking=True)

            optimizer.zero_grad()
            with autocast('cuda'):
                outputs = model(input_ids, attention_mask)
                loss = criterion(outputs, labels)

            # Scale the loss before backward, then step the optimizer through the scaler.
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            scheduler.step()

            loss_value = loss.item()
            total_loss += loss_value

            preds = torch.argmax(outputs, dim=1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)

            train_bar.set_postfix({
                'loss': f'{loss_value:.4f}',
                'acc': f'{(correct / total) * 100:.2f}%'
            })

        train_acc = (correct / total * 100) if total > 0 else 0
        avg_train_loss = total_loss / max(len(train_loader), 1)

        val_metrics = _validate_epoch(model, criterion, val_loader, f"Epoch {epoch}/{epochs_total} [Val]")
        avg_val_loss = val_metrics['loss']
        accuracy = val_metrics['accuracy']
        precision = val_metrics['precision']
        recall = val_metrics['recall']
        f1 = val_metrics['f1']

        print(f"\nEpoch {epoch}/{epochs_total} | "
              f"Train Loss: {avg_train_loss:.4f}, "
              f"Train Acc: {train_acc:.2f}% | "
              f"Val Loss: {avg_val_loss:.4f}, "
              f"Val Acc: {accuracy:.2f}%, Prec: {precision:.4f}, Rec: {recall:.4f}, F1: {f1:.4f}\n")

        # Update the running best BEFORE writing the checkpoint, so that the stored
        # 'best_accuracy' includes this epoch. Otherwise a resumed run starts from a
        # stale value and may overwrite best_model.pt with a worse model.
        is_best = accuracy > best_accuracy
        if is_best:
            best_accuracy = accuracy

        checkpoint_data = {
            'model_state_dict': core_model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'scaler_state_dict': scaler.state_dict(),
            'epoch': epoch,
            'best_accuracy': best_accuracy,
        }
        ckpt_filename = f"checkpoint_epoch_{epoch}.pt"
        torch.save(checkpoint_data, ckpt_filename)

        upload_or_replace_file(service, CHECKPOINT_FOLDER_ID, ckpt_filename)

        if is_best:
            torch.save(core_model.state_dict(), "best_model.pt")
            upload_or_replace_file(service, CHECKPOINT_FOLDER_ID, "best_model.pt")

    print(f"[INFO] Training complete. Total time: {int((time.time() - training_start)/60)} minutes")



In [None]:
# Start (or resume) training with the objects configured in the cells above.
train(
    model=model,
    optimizer=optimizer,
    scheduler=scheduler,
    criterion=criterion,
    train_loader=train_loader,
    val_loader=val_loader,
    scaler=scaler,
    epochs_total=epochs,
)