In [1]:
import os
import time

import numpy as np
import torch
import tqdm
from torch.utils.data import DataLoader

In [2]:
from sklearn.metrics import precision_score, recall_score, f1_score

In [3]:
from language_detection.data import load_wili_2018_dataset, BytesDataset, batch_collate_function, get_mask_from_lengths

In [4]:
# Directory containing the WiLI-2018 dataset files.
# Set the WILI_2018_DATA_PATH environment variable to point at a local copy; the original
# machine-specific location is kept as the fallback so existing setups keep working.
wili_2018_data_path = (
    os.environ.get("WILI_2018_DATA_PATH")
    or "/home/derek/PythonProjects/language_detection/datasets/WiLi_2018"
)

In [5]:
# Load the dataset from disk. Later cells read x_train / x_test / y_train / y_test and the
# lang2idx mapping from the returned object.
wiki_dataset = load_wili_2018_dataset(wili_2018_data_path)

[32m2023-11-26 18:20:17.601[0m | [1mINFO    [0m | [36mlanguage_detection.data.loaders[0m:[36mload_wili_2018_dataset[0m:[36m67[0m - [1m'drop_duplicates' is true, dropping duplicates from *training* set...[0m
[32m2023-11-26 18:20:17.661[0m | [1mINFO    [0m | [36mlanguage_detection.data.loaders[0m:[36mload_wili_2018_dataset[0m:[36m74[0m - [1mdropped 3117 samples from training data that also appeared in the test data[0m


In [6]:
# Sanity check: number of distinct texts that appear in both the training and the test split.
len(set(wiki_dataset.x_train) & set(wiki_dataset.x_test))

0

In [7]:
# Show the annotated field names and declared types of the dataset container.
wiki_dataset.__annotations__

{'x_train': list[str],
 'x_test': list[str],
 'y_train': list[str],
 'y_test': list[str],
 'idx2lang': dict[int, str],
 'lang2idx': dict[str, int],
 'labels': dict[str, str] | None,
 'dropped': list[str] | None}

In [8]:
# Peek at the first three training texts.
wiki_dataset.x_train[:3]

['Klement Gottwaldi surnukeha palsameeriti ning paigutati mausoleumi. Surnukeha oli aga liiga hilja ja oskamatult palsameeritud ning hakkas ilmutama lagunemise tundemärke. 1962. aastal viidi ta surnukeha mausoleumist ära ja kremeeriti. Zlíni linn kandis aastatel 1949–1989 nime Gottwaldov. Ukrainas Harkivi oblastis kandis Zmiivi linn aastatel 1976–1990 nime Gotvald.',
 'Sebes, Joseph; Pereira Thomas (1961) (på eng). The Jesuits and the Sino-Russian treaty of Nerchinsk (1689): the diary of Thomas Pereira. Bibliotheca Instituti historici S. I., 99-0105377-3 ; 18. Rome. Libris 677492',
 'भारतीय स्वातन्त्र्य आन्दोलन राष्ट्रीय एवम क्षेत्रीय आह्वान, उत्तेजनासभ एवम प्रयत्नसँ प्रेरित, भारतीय राजनैतिक सङ्गठनद्वारा सञ्चालित अहिंसावादी आ सैन्यवादी आन्दोलन छल, जेकर एक समान उद्देश्य, अङ्ग्रेजी शासनक भारतीय उपमहाद्वीपसँ जडीसँ उखाड फेकनाई छल। ई आन्दोलनक शुरुआत १८५७ मे भेल सिपाही विद्रोहक मानल जाइत अछि। स्वाधीनताक लेल हजारो लोग अपन प्राणक बलि देलक। भारतीय राष्ट्रीय कांग्रेस १९३० कांग्रेस अधिवेशन मे अङ्

In [9]:
# Byte-length limit used by the truncation experiment in the next cell.
max_length = 1024

In [10]:
# Prototype of byte-level truncation for a single sample (x_train[2]).
# Each character is encoded to UTF-8 on its own so that truncation always happens on a
# character boundary and never splits a multi-byte character.
char_byte_sequences = [list(c.encode("utf8")) for c in wiki_dataset.x_train[2]]
char_seq_lens = [len(c) for c in char_byte_sequences]
if sum(char_seq_lens) > max_length:
    # Keep the longest prefix of whole characters (excluding the last character) whose total
    # byte count is <= max_length - 2. Prefix sums grow strictly, so a single forward pass with
    # a running total finds the same cut as re-summing every prefix from the end, in O(n).
    # NOTE(review): truncation triggers only above max_length but cuts down to max_length - 2
    # (presumably leaving room for two special tokens); texts of max_length - 1 or max_length
    # bytes are therefore left untruncated. Confirm whether that is intended.
    byte_budget = max_length - 2
    prefix_bytes = 0
    kept_chars = None
    for char_idx, n_bytes in enumerate(char_seq_lens):
        if prefix_bytes > byte_budget:
            break
        kept_chars = char_idx
        prefix_bytes += n_bytes
    if kept_chars is not None:
        char_byte_sequences = char_byte_sequences[:kept_chars]
# Flatten the per-character byte lists into one sequence of integer byte values.
byte_sequence = [byte_val for char in char_byte_sequences for byte_val in char]
    



In [11]:
class TransformerClassifier(torch.nn.Module):
    def __init__(
        self,
        num_features=256 + 4,
        num_classes=235,
        transformer_layer_count=4,
        ffn_dims=1024,
        output_dims=512,
        attn_heads=4,
    ):
        """
        joint classification and masked language model transformer encoder

        All hyperparameters default to the values that were previously hard-coded, so
        ``TransformerClassifier()`` builds exactly the same architecture as before.

        Args:
            num_features: size of the input vocabulary and of the MLM output layer. The default
                is 256 + 4 (presumably the 256 byte values plus 4 special ids; confirm against
                the dataset code).
            num_classes: number of outputs of the classification head.
            transformer_layer_count: number of stacked ``TransformerEncoderLayer`` modules.
            ffn_dims: hidden size of the feed-forward sub-layer in each encoder layer.
            output_dims: embedding / model width (``d_model``).
            attn_heads: number of attention heads per encoder layer.

        NOTE(review): forward() feeds the raw token embeddings straight into the encoder layers;
        no positional encoding is added anywhere in this class, so the encoder has no
        information about token order. Confirm whether that is intended for the MLM objective.
        """
        super().__init__()

        self.num_features = num_features
        self.num_classes = num_classes
        self.transformer_layer_count = transformer_layer_count
        self.ffn_dims = ffn_dims
        self.output_dims = output_dims
        self.attn_heads = attn_heads

        self.embedding = torch.nn.Embedding(num_embeddings=self.num_features, embedding_dim=self.output_dims)
        self.transformer_layers = torch.nn.ModuleList([
            torch.nn.TransformerEncoderLayer(
                d_model=self.output_dims,
                nhead=self.attn_heads,
                dim_feedforward=self.ffn_dims,
                activation="gelu",
                batch_first=True
            ) for _ in range(self.transformer_layer_count)
        ])
        # Two heads share the encoder: one classifies the sequence, one predicts a token id per position.
        self.clf_layer = torch.nn.Linear(in_features=self.output_dims, out_features=self.num_classes)
        self.mlm_layer = torch.nn.Linear(in_features=self.output_dims, out_features=self.num_features)

    def forward(self, x, pad_mask):
        """
        Run the encoder and both output heads.

        Args:
            x: integer token ids of shape (batch, seq); the layers are built with batch_first=True.
            pad_mask: mask forwarded unchanged as ``src_key_padding_mask`` to every encoder layer.

        Returns:
            Tuple ``(clf_preds, mlm_preds)``: classification logits of shape (batch, num_classes),
            computed from the encoder output at sequence position 0, and per-position logits of
            shape (batch, seq, num_features).
        """
        x = self.embedding(x)
        for lyr in self.transformer_layers:
            x = lyr(x, src_key_padding_mask=pad_mask)
        mlm_preds = self.mlm_layer(x)
        # The first position of every sequence is used as the summary vector for classification.
        clf_preds = self.clf_layer(x[:, 0, :])
        return clf_preds, mlm_preds

In [12]:
# Build the model and move it to the GPU (Module.to returns the module itself).
model = TransformerClassifier().to("cuda")

# Report every parameter that did not end up on a CUDA device.
stragglers = [
    (param_name, tensor.device)
    for param_name, tensor in model.named_parameters()
    if tensor.device.type != "cuda"
]
for param_name, param_device in stragglers:
    print(f"param '{param_name}' is on device '{param_device}'")
print("all unmentioned params on cuda!")


all unmentioned params on cuda!


In [13]:
from zmq import device

# Hold out the last 10% of the training split as a dev set.
# NOTE(review): the split is positional (no shuffle); this assumes x_train is not ordered by
# language. Confirm against the loader.
train_dev_split = int(len(wiki_dataset.x_train) * 0.9)

train_dataset = BytesDataset(
    texts=wiki_dataset.x_train[:train_dev_split],
    languages=wiki_dataset.y_train[:train_dev_split],
    mapping=wiki_dataset.lang2idx,
    max_length=1024,
    is_training=True
)
# NOTE(review): the dev set is also built with is_training=True. If that flag changes the
# inputs (e.g. random masking), dev metrics are measured on altered inputs; confirm intent.
dev_dataset = BytesDataset(
    texts=wiki_dataset.x_train[train_dev_split:],
    languages=wiki_dataset.y_train[train_dev_split:],
    mapping=wiki_dataset.lang2idx,
    max_length=1024,
    is_training=True
)
# Fixed: the test texts must come from x_test. Previously y_test (the labels) was passed as
# `texts`, so the model would have been evaluated on label strings instead of the documents.
test_dataset = BytesDataset(
    texts=wiki_dataset.x_test,
    languages=wiki_dataset.y_test,
    mapping=wiki_dataset.lang2idx,
    max_length=1024,
    is_training=False
)
print(f"train: {len(train_dataset)}, dev: {len(dev_dataset)}, test: {len(test_dataset)}")

train: 102944, dev: 11439, test: 117500


In [14]:
# Settings shared by all three loaders; only the training loader shuffles.
shared_loader_kwargs = dict(batch_size=32, num_workers=0, collate_fn=batch_collate_function)

train_dataloader = DataLoader(train_dataset, shuffle=True, **shared_loader_kwargs)
dev_dataloader = DataLoader(dev_dataset, shuffle=False, **shared_loader_kwargs)
test_dataloader = DataLoader(test_dataset, shuffle=False, **shared_loader_kwargs)

In [15]:
# Joint training of the classification head and the masked-language-model head, followed by a
# dev-set evaluation of the classifier after every epoch.
total_epochs = 10
accumulate_steps = 4  # number of mini-batches whose gradients are summed before one optimizer step
global_steps = 0      # counts mini-batches (not optimizer steps) across all epochs

# Both losses are summed (not averaged) over the batch; MLM targets equal to -1 are ignored.
mlm_criterion = torch.nn.CrossEntropyLoss(reduction="sum", ignore_index=-1)
clf_criterion = torch.nn.CrossEntropyLoss(reduction="sum")
optimizer = torch.optim.AdamW(params=model.parameters(), lr=0.0001)

num_train_batches = len(train_dataloader)
# One optimizer/scheduler step per full accumulation group plus one for a trailing partial
# group, i.e. ceil(num_train_batches / accumulate_steps). This must match the number of
# scheduler.step() calls made below, otherwise the one-cycle schedule is cut short or overrun.
optimizer_steps_per_epoch = (num_train_batches + accumulate_steps - 1) // accumulate_steps
scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer,
    max_lr=0.001,
    epochs=total_epochs,
    steps_per_epoch=optimizer_steps_per_epoch,
    pct_start=0.1
)

for epoch in range(total_epochs):
    epoch_losses = []
    model.train()
    # Start every epoch from clean gradients.
    optimizer.zero_grad()
    print(f"epoch {epoch+1} training")
    time.sleep(0.5)  # give the print above time to flush before tqdm starts drawing
    for batch_idx, minibatch in enumerate(pbar := tqdm.tqdm(train_dataloader, total=num_train_batches)):
        # format data and move to gpu
        x, y, seq_lens, mask_indices, targets = minibatch
        x = x.to("cuda")
        y = y.to("cuda")
        targets = targets.to("cuda")
        pad_mask = get_mask_from_lengths(seq_lens, 1024, x.device)
        # model forward (call the module, not .forward, so that module hooks run)
        clf_logits, mlm_logits = model(x, pad_mask)
        # MLM targets: -1 (ignored by the criterion) everywhere except at the positions listed
        # in mask_indices, where the value from y is used.
        masked_y = torch.full_like(y, -1)
        for i in range(y.shape[0]):
            row_positions = mask_indices[i].long()
            masked_y[i, row_positions] = y[i, row_positions]
        # CrossEntropyLoss expects the class dimension second: (batch, classes, seq)
        mlm_loss = mlm_criterion(torch.transpose(mlm_logits, 1, 2), masked_y)
        clf_loss = clf_criterion(clf_logits, targets)
        global_steps += 1
        # windowed loss display: mean over the five most recent mini-batches
        epoch_losses.append((clf_loss.item(), mlm_loss.item()))
        recent_losses = epoch_losses[-5:]
        clf_losses = np.mean([pair[0] for pair in recent_losses])
        mlm_losses = np.mean([pair[1] for pair in recent_losses])
        pbar.set_description(f"step {global_steps}: clf: {clf_losses:.3f}, mlm: {mlm_losses:.3f}")
        # gradient accumulation: scale so a full group contributes the average of its batches
        ttl_loss = (mlm_loss + clf_loss) / accumulate_steps
        ttl_loss.backward()
        # Step after every `accumulate_steps` mini-batches (1-based count) and also after the
        # final mini-batch, so gradients of a trailing partial group are applied, not dropped.
        is_group_complete = (batch_idx + 1) % accumulate_steps == 0
        is_last_batch = (batch_idx + 1) == num_train_batches
        if is_group_complete or is_last_batch:
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()

    # dev eval (classification head only)
    print(f"epoch {epoch+1} dev eval")
    time.sleep(0.5)
    model.eval()
    epoch_dev_loss = []
    epoch_targets = []
    epoch_predictions = []
    with torch.no_grad():
        for minibatch in tqdm.tqdm(dev_dataloader, total=len(dev_dataloader)):
            # format data and move to gpu
            x, y, seq_lens, mask_indices, targets = minibatch
            x = x.to("cuda")
            targets = targets.to("cuda")
            pad_mask = get_mask_from_lengths(seq_lens, 1024, x.device)
            # model forward; the MLM logits are not needed for the classification metrics
            clf_logits, _mlm_logits = model(x, pad_mask)
            clf_loss = clf_criterion(clf_logits, targets)
            epoch_dev_loss.append(clf_loss.item())
            epoch_targets += targets.detach().cpu().numpy().tolist()
            epoch_predictions += clf_logits.max(1).indices.detach().cpu().numpy().tolist()
    # clf_criterion sums over the batch, so divide by the sample count for a per-sample mean
    print(f"dev clf loss: {sum(epoch_dev_loss) / max(len(epoch_targets), 1):.4f}")
    print(f"dev micro prc: {precision_score(epoch_targets, epoch_predictions, average='micro')}")
    print(f"dev micro rcl: {recall_score(epoch_targets, epoch_predictions, average='micro')}")
    print(f"dev micro f1b: {f1_score(epoch_targets, epoch_predictions, average='micro')}")

    

epoch 1 training


step 2: clf: 90.068, mlm: 6734.601:   0%|          | 1/6434 [00:00<1:15:46,  1.41it/s]

torch.Size([16, 1024]) torch.Size([16, 1024]) torch.Size([16, 1024]) torch.Size([16])


step 244: clf: 70.579, mlm: 3189.291:   4%|▍         | 244/6434 [00:59<25:12,  4.09it/s]


KeyboardInterrupt: 