In [None]:
import sys
from pathlib import Path

sys.path.append("../src")

# Basic imports
import pandas as pd
import numpy as np

# PyTorch imports
import torch.nn as nn
import torch
from torch.utils.data import DataLoader
from transformers import BertTokenizer

# Tokenize sentences
from nltk.tokenize import sent_tokenize

# Utils
from dataset_building import build_dataset
from model import init_model
from trainer import train_epoch

# Progress bar
from tqdm import tqdm
tqdm.pandas()

%load_ext autoreload
%autoreload 2

In [None]:
# Read the example corpus: a small sample of the Reuters dataset.
df = pd.read_json("../data/reuters_sample.json")

# Lower-case every document, then split it into a list of sentences
# (two separate passes, each with its own progress bar).
lowercased = df["text"].progress_apply(lambda doc: doc.lower())
df["text"] = lowercased.progress_apply(sent_tokenize)

# Turn the tokenized documents into the training dataset.
# NOTE(review): the exact meaning of the two keyword arguments is defined in
# dataset_building.build_dataset -- confirm there before tuning them.
dataset = build_dataset(
    df["text"],
    masking_percentage=0.5,
    max_pairs_per_doc=2,
)

In [None]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Build the model through the project helper, handing it the chosen device.
model = init_model(device)

# Pretrained uncased BERT tokenizer.
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

# Cross-entropy objective, optimised with AdamW over all model parameters.
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(params=model.parameters(), lr=5e-5)

In [None]:
# Training
epochs = 3

# torch.save does not create missing parent directories, so make sure the
# checkpoint folder exists up front instead of failing after a full epoch.
save_dir = Path("saved_models")
save_dir.mkdir(parents=True, exist_ok=True)

# Construct the DataLoader once: its configuration is identical for every
# epoch, and with shuffle=True the sample order is redrawn on each iteration.
# NOTE(review): only the first 10 items of `dataset` are used here, which looks
# like a quick smoke-test setting -- pass the full `dataset` for a real run.
dataloader = DataLoader(dataset[:10],
                        batch_size = 16,
                        shuffle = True)

for epoch in range(epochs):
    print(f"Epoch {epoch}")

    # Train for 1 epoch
    # (print_each is presumably the logging interval -- see trainer.train_epoch)
    train_epoch(model = model,
                tokenizer = tokenizer,
                dataloader = dataloader,
                optimizer = optimizer,
                criterion = criterion,
                device = device,
                print_each = 500,
                disable_progress_bar = False)

    # Save model weights after epoch (one checkpoint file per epoch)
    save_path = save_dir / f"model_{epoch}epoch.pt"
    torch.save(model.state_dict(), save_path)
    print("Model saved.\n\n")
    