In [3]:
import math
import pickle
from matplotlib import pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import pandas as pd
import ../constants
from sklearn.model_selection import KFold
from torch.utils.data import DataLoader, SubsetRandomSampler
from model import TextClassifier
from torch.optim import AdamW
from torch.optim.lr_scheduler import OneCycleLR

from dataset.news_dataset import NewsDataset
from dataset.preprocessing import collate_batch, get_vocab_size

from vissualize.plot import plot_learning_curve, plot_gradients, plot_ud_ratios

from torch.cuda.amp import autocast, GradScaler

from colorama import Fore, Style, init



def fine_tune(df, device, epochs=20, percent_to_train=0.5):
    """Fine-tune the model on the scraped data."""
    print(f"{Fore.YELLOW}Fine-tuning the model on {device}...{Style.RESET_ALL}")
    # Load the model
    model = TextClassifier(get_vocab_size(), embed_dim, num_class, num_heads=num_heads, dropout_rate=dropout_rate, layer_size=layer_size, number_of_layers=number_of_layers)
    model.load_state_dict(torch.load('./results/best_model.pth'))
    model = model.to(device)

    
    # Split the data to percent_to_train for training
    df = df.sample(frac=1).reset_index(drop=True) # Shuffle the data
    train_size = int(percent_to_train * len(df))

    df = df[:train_size]

    fine_tune_dataset = NewsDataset(df['body'].reset_index(drop=True), df['Category'].reset_index(drop=True))
    fine_tune_loader = DataLoader(fine_tune_dataset, batch_size=constants.batch_size, collate_fn=collate_batch)
    
    optimizer = AdamW(model.parameters(), lr=constants.lr, weight_decay=1e-2)
    scheduler = OneCycleLR(optimizer, max_lr=1e-2, steps_per_epoch=len(fine_tune_loader), epochs=epochs)
    loss_fn = nn.CrossEntropyLoss()

    best_loss = float('inf')

    model.train()
    total_loss, total_count = 0.0, 0
    for epoch in range(epochs):
        for X, y in fine_tune_loader:
            X, y = X.to(device), y.to(device)
            pred = model(X)
            loss = loss_fn(pred, y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item() * X.size(0)
            total_count += X.size(0)
            if total_loss < best_loss:
                best_loss = total_loss
                torch.save(model.state_dict(), './results/fine_tuned_model.pth')
        print(f'Epoch {epoch+1}/{epochs}: Training Loss = {total_loss / total_count:.6f}')
        scheduler.step()

SyntaxError: invalid syntax (1301273796.py, line 8)