In [None]:
import time
import torch
import torchvision
import torch.optim
from torch.utils.data import Dataset, DataLoader
import torch.utils.data
import pandas as pd

: 

In [23]:
# Vocabulary class
class Vocabulary:
    def __init__(self, freq_threshold: int):
        '''any word that appears below freq_threshold number of times will not be included in the vocabulary'''
        self.freq_threshold = freq_threshold
        # index to string
        self.itos = { 0: "<PAD>", 1: "<SOS>", 2: "<EOS>", 3: "<UNK>" }
        # string to index
        self.stoi = { "<PAD>": 0, "<SOS>": 1, "<EOS>": 2, "<UNK>": 3 }
        
    def __len__(self):
        return len(self.itos)

    @staticmethod
    def tokenizer(text: str) -> list[str]:
        # TODO: tokenize text as done in preprocessing
        return text.split()

    def build_vocabulary(self, sentence_list: list[str]):
        frequencies = {}
        i = len(self.itos) # currently 4

        for sentence in sentence_list:
            print(sentence)
            for word in self.tokenizer(sentence):
        
                if word not in frequencies:
                    frequencies[word] = 1
                else:
                    frequencies[word] += 1
                if frequencies[word] >= self.freq_threshold:
                    self.stoi[word] = i
                    self.itos[i] = word
                    i += 1

    def numericalize(self, text: str) -> list[int]:
        tokenized_text = self.tokenizer(text)
        return [
            self.stoi[token] if token in self.stoi else self.stoi["<UNK>"] for token in tokenized_text
        ]

In [21]:
vocab = Vocabulary(0)
sentences = ["hello this is a sentece", "this is another sentence"]
vocab.build_vocabulary(sentences)

print(vocab.itos)
print(vocab.stoi)
print(vocab.numericalize("this is sentence"))

hello this is a sentece
this is another sentence
{0: '<PAD>', 1: '<SOS>', 2: '<EOS>', 3: '<UNK>', 4: 'hello', 5: 'this', 6: 'is', 7: 'a', 8: 'sentece', 9: 'this', 10: 'is', 11: 'another', 12: 'sentence'}
{'<PAD>': 0, '<SOS>': 1, '<EOS>': 2, '<UNK>': 3, 'hello': 4, 'this': 9, 'is': 10, 'a': 7, 'sentece': 8, 'another': 11, 'sentence': 12}
[9, 10, 12]


In [22]:
# Dataset class
class MyDataset(Dataset):
    def __init__(self, file_path: str, freq_threshold=5):
        self.dir = dir
        self.df = pd.read_csv(file_path)

        self.paragraphs = self.df["paragraph"]
        self.titles = self.df["title"]

        self.vocab = Vocabulary(freq_threshold)
        self.vocab.build_vocabulary(self.paragraphs.tolist() + self.titles.tolist())
        
        self.word_to_paragraph = {}

        paragraph_index = 0 # paragraph that this word in title belongs to
        word_rank = 0 # position (index) in the current title
        for word_index, word in enumerate(self.titles):
            self.word_to_paragraph[word_index] = (paragraph_index, word_rank)
            if word == "<EOS>":
                paragraph_index += 1
                word_rank = 0
            else:
                word_rank += 1
            

    def __len__(self):
        return len(self.df)

    def __getitem__(self, index):
        paragraph_index = self.word_to_paragraph[index]
        paragraph, cur_rank = self.paragraphs[paragraph_index]
        next_word = self.titles[index]
        title_so_far = self.titles[index-cur_rank:index]

        # TODO: convert title to one-hot encoded before returning

        return (paragraph, title_so_far), next_word

In [10]:
def get_data_loader(
    file_path: str,
    dataset: MyDataset,
    freq_threshold=5,
    batch_size=32,
    num_workers=8,
    splits=[0.8,0.1,0.1]):
    '''
    dataset: torchvision.datasets a transformer dataset for training, testing, and validation
    batch_size: int
    splits: list(str) train-validation-test split
    return: DataLoader
    '''
    dataset = TrainDataset(file_path, freq_threshold)

    assert sum(splits) == 1, "ensure sum of train-validation-test split adds up to 1"

    # perform split
    size = len(dataset)
    l1, l2 = int(size*splits[0]), int(size*splits[1])
    l3 = size - l1 - l2

    train_set, val_set, test_set = torch.utils.data.random_split(
        dataset,
        [l1, l2, l3],
        generator=torch.Generator().manual_seed(999)
    )

    # get data loaders
    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    val_loader = DataLoader(val_set, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=True, num_workers=num_workers)

    return train_loader, val_loader, test_loader

In [11]:
def train(data_loader: DataLoader, model, loss_function, optimizer, scheduler=None, epochs=30): 
    losses_over_epochs = []
    num_batches = len(data_loader)

    for epoch in epochs:
        start = time.time()
        total_loss = 0
        for (paragraphs, titles) in data_loader:
            # forward step
            out = model(paragraphs)

            # loss
            loss = loss_function(out, titles)
            total_loss += loss.item()

            # back propagation
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        
        # learning rate scheduler update
        if scheduler is not None:
            scheduler.step()

        # finished one epoch of training
        end = time.time()
        print(f"Completed epoch {epoch+1} | average loss: {total_loss/num_batches} | time: {end-start}s")
        losses_over_epochs.append(total_loss/num_batches)

    return losses_over_epochs