In [1]:
import os
from os.path import exists
from typing import Dict, List, Optional
from collections import Counter
import csv
import torch
from torch import nn, Tensor
import torch.optim as optim
from torch.nn.functional import log_softmax, pad
import math
import copy
import time
from tqdm import tqdm
import torchmetrics
from torch.optim.lr_scheduler import LambdaLR
import pandas as pd
from torchtext.data.functional import to_map_style_dataset
from torch.utils.data import DataLoader
from torchtext.vocab import build_vocab_from_iterator
import torchtext.datasets as datasets
import spacy
import GPUtil
import warnings
from torch.utils.data.distributed import DistributedSampler
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.nn import TransformerEncoder, TransformerEncoderLayer
import torch.nn.functional as F
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

from LanguageDataset import LanguageDataset
from utils import load_raw_data, predict

In [2]:
# Root of the preprocessed corpora; the Aymara splits sit in the "aymara/" subfolder.
main_folder = "./processed_data/"
aymara_folder = f"{main_folder}aymara/"

# Seed torch's global RNG so runs are repeatable.
torch.manual_seed(42)
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Pretrained NLLB-200 (distilled, 600M) checkpoint identifier on the Hugging Face Hub.
model_name = "facebook/nllb-200-distilled-600M"

print("Model Loading . . . . . . . . . . . . . . . .")
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
model = model.to(device)  # nn.Module.to moves parameters and returns the same module
print("Model Loaded")

Model Loading . . . . . . . . . . . . . . . .
Model Loaded


In [7]:
def train(
    model: nn.Module,
    dataloader: DataLoader,
    tokenizer: AutoTokenizer,
    optimizer: optim.Optimizer,
    device: torch.device,
    epoch: int,
    max_grad_norm: float = 0.5,
) -> float:
    """Fine-tune a seq2seq model for one epoch using teacher forcing.

    Each batch is indexed as ``batch[0]`` = source token ids, ``batch[1]`` =
    source mask, and ``batch[2]`` = target token ids. All are moved to ``device``
    and fed straight to the model's forward pass, so they must already be
    ``(batch, seq_len)`` tensors.

    Args:
        model: A Hugging Face seq2seq model. Calling it with ``labels=`` must
            return an object exposing ``.loss``.
        dataloader: Yields the batches described above.
        tokenizer: Used only for ``pad_token_id``, so that padding in the
            targets is excluded from the loss.
        optimizer: Optimizer over ``model``'s parameters.
        device: Device the batches are moved to.
        epoch: Epoch index, used only for logging.
        max_grad_norm: Gradient-norm clipping threshold (default 0.5, the
            value previously hard-coded here).

    Returns:
        The epoch's mean loss, weighted by the number of non-padding target
        tokens in each batch. It is also printed.
    """
    # `compute_on_step` is deprecated in torchmetrics, so the default
    # constructor is used instead.
    loss_metric = torchmetrics.MeanMetric().to(device)
    model.train()

    for batch in tqdm(dataloader):
        input_ids = batch[0].to(device)
        # NOTE(review): assumed to be a Hugging Face-style attention mask
        # (1 = real token, 0 = padding). The earlier draft of this loop
        # treated it as a padding mask instead. Confirm against LanguageDataset.
        attention_mask = batch[1].to(device)
        labels = batch[2].to(device)
        # -100 is the ignore_index of the model's cross-entropy loss. Masking
        # pad positions keeps padding from inflating or biasing the loss.
        labels = labels.masked_fill(labels == tokenizer.pad_token_id, -100)

        optimizer.zero_grad()
        # A forward pass with `labels` yields a differentiable loss.
        # `model.generate` returns discrete token ids and cannot be trained
        # through.
        outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
        loss = outputs.loss
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
        optimizer.step()

        # Weight each batch by its count of real target tokens, so the epoch
        # value is a per-token mean.
        n_target_tokens = (labels != -100).sum().float()
        loss_metric.update(loss.detach(), n_target_tokens)

    epoch_loss = loss_metric.compute().item()
    print(f"| Epoch {epoch} | loss {epoch_loss:.4f}")
    return epoch_loss

In [9]:
aymara_dev_raw = load_raw_data(f"{aymara_folder}dev.aym")
# Keep only the first 50 source sentences so this evaluation pass stays quick.
# NOTE(review): only 'src_text' is truncated. If load_raw_data also returns a
# target/reference field, it keeps its full length; confirm that
# LanguageDataset only reads 'src_text'.
aymara_dev_raw["src_text"] = aymara_dev_raw["src_text"][:50]

# NLLB tokenizer configured with "ayr_Latn" as the source-language code.
tokenizer = AutoTokenizer.from_pretrained(model_name, src_lang="ayr_Latn")

aymara_dev_data = LanguageDataset(aymara_dev_raw)
aymara_dev_dataloader = DataLoader(
    aymara_dev_data,
    batch_size=2,
)


In [10]:
# Reuse the tokenizer built together with the dev dataloader rather than
# reloading the identical "ayr_Latn" tokenizer from the hub a second time.
preds = predict(model, aymara_dev_dataloader, tokenizer, device)


100%|██████████| 25/25 [00:15<00:00,  1.60it/s]


In [16]:
# Translate the first two dev sentences as a quick sanity check.
inputs = tokenizer(aymara_dev_raw['src_text'][:2], padding='max_length', truncation=True, max_length=256, return_tensors="pt").to(device)
# `lang_code_to_id` is deprecated and has been removed from newer transformers
# releases. Language codes are ordinary vocabulary tokens, so convert_tokens_to_ids
# works across versions. The target sentence is forced to start with Spanish.
# Despite its name, `logits` holds the generated token ids returned by generate().
logits = model.generate(**inputs, max_length=256, forced_bos_token_id=tokenizer.convert_tokens_to_ids("spa_Latn"))


In [18]:
# Display the output of the generate() call above. These are token ids rather
# than logits; the runs of 1s after the end token appear to be padding.
logits

tensor([[     2, 256161,  66948,   1125,  31915,      2,      1,      1,      1,
              1,      1,      1,      1,      1,      1,      1,      1,      1],
        [     2, 256161,   1446,   2048,  26192, 237854,    356,    629,   5766,
            153,    336,  51015,     79,  18442,  80826,  18550, 248075,      2]],
       device='cuda:0')

In [None]:
# Write the pretrained model's translations to disk, one sentence per line, so
# they line up with the dev references for scoring.
# UTF-8 is explicit because the decoded text contains non-ASCII characters and
# the platform default encoding is not guaranteed to handle them.
with open(aymara_folder + "pretrain_result.txt", "w", encoding="utf-8") as f:
    for token_ids in preds:
        # NOTE(review): each element of `preds` is assumed to be a batch of
        # generated ids (2-D, one row per sentence), as produced by generate().
        # A 1-D tensor is treated as a single sentence.
        if isinstance(token_ids, torch.Tensor) and token_ids.dim() == 1:
            token_ids = token_ids.unsqueeze(0)
        # The old code joined the whole batch onto one line. Writing each
        # decoded sentence on its own line keeps line count == sentence count.
        for sentence in tokenizer.batch_decode(token_ids, skip_special_tokens=True):
            f.write(sentence + "\n")