In [None]:
# torchtext must be available in the kernel; uncomment to install it:
# !pip install torchtext
is_kaggle = True
if is_kaggle:
    working_dir = '/kaggle/input/'  # Kaggle's mount point for attached datasets
else:
    working_dir = ''
# NOTE(review): later cells hardcode '/kaggle/input/...' paths instead of using working_dir.

In [None]:
import wandb
from kaggle_secrets import UserSecretsClient

# The Weights & Biases API key is stored as a Kaggle secret labelled "wanb".
user_secrets = UserSecretsClient()
secret_value_0 = user_secrets.get_secret("wanb")
wandb.login(key=secret_value_0)

In [None]:
import pandas as pd
import numpy as np
import torch
import torch.nn.functional as F
import torch.nn as nn
import torchtext
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset

from tqdm.auto import tqdm

## Text preprocessing

In [None]:
def str_to_list(value):
    """Parse a stringified list of quoted tokens, e.g. "['a', 'b']" -> ['a', 'b'].

    Leading/trailing '[' and ']' characters are stripped, the remainder is
    split on ", ", and the first and last character (the quotes) of each item
    are dropped. Items that themselves contain ", " are not supported.

    An empty list ("[]") yields [] rather than [''].
    """
    inner = value.strip('[]')
    if not inner:
        # ''.split(', ') == [''] would otherwise yield one spurious empty token.
        return []
    return [item[1:-1] for item in inner.split(', ')]

In [None]:
import os

# Inspect which files the attached Kaggle dataset provides.
aml_dataset_dir = '/kaggle/input/aml-dataset'
os.listdir(aml_dataset_dir)

In [None]:
# Local (non-Kaggle) copy: "../datasets/tonetags_wsd_1.csv"
# The "text" column holds stringified token lists; str_to_list turns them into Python lists.
dataset = pd.read_csv(
    "/kaggle/input/aml-dataset/tonetags_wsd_1.csv",
    index_col=0,
    converters={"text": str_to_list},
)

In [None]:
# Alternative check: dataset.tags.nunique()/len(dataset)
# Share of rows carrying each tag (denominator is the full row count).
tag_counts = dataset["tags"].value_counts()
tag_counts / len(dataset)

In [None]:
# Encode each tag as its position in order of first appearance; `labels[i]` maps id i back to the tag.
labels = dataset["tags"].unique().tolist()
dataset["tags"] = dataset["tags"].apply(labels.index)

In [None]:
vocab = torchtext.vocab.GloVe(name='6B', dim=50).stoi
# Special tokens receive the next free ids after the GloVe entries, in this order;
# the embedding matrix built later appends one row for each of them.
for special_token in ("<unk>", "<pad>"):
    vocab[special_token] = len(vocab)

In [None]:
max_length = 2 ** 12  # 4096 tokens: longer sentences are dropped, shorter ones are padded up to this

In [None]:
class myDataset(Dataset):
    """Token-id dataset built from a frame with ``text`` (token lists) and ``tags`` (label ids).

    Sentences longer than ``max_len`` tokens are dropped together with their
    labels. Tokens missing from the vocabulary map to the ``"<unk>"`` id.
    """

    def __init__(self, dataset, token_to_id=None, max_len=None):
        """
        Args:
            dataset: DataFrame-like object with ``text`` and ``tags`` columns.
            token_to_id: token -> id mapping that contains ``"<unk>"``; defaults
                to the notebook-global ``vocab``.
            max_len: longest sentence (in tokens) that is kept; defaults to the
                notebook-global ``max_length``.
        """
        if token_to_id is None:
            token_to_id = vocab
        if max_len is None:
            max_len = max_length
        unk_id = token_to_id["<unk>"]

        self.data = []
        self.labels = []
        # Filter sentences and labels together so index i refers to the same row in
        # both lists; keeping every label while skipping long sentences misaligns them.
        for sentence, label in zip(dataset.text, dataset.tags):
            if len(sentence) > max_len:
                continue
            self.data.append([token_to_id.get(token, unk_id) for token in sentence])
            self.labels.append(label)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(token_ids, label_tensor)``.

        The id list is a copy, so a collate function that pads it in place
        does not modify the stored sample.
        """
        return list(self.data[idx]), torch.tensor(self.labels[idx])

In [None]:
def collate_fn(batch, max_len=None, pad_id=None):
    """Collate ``(token_ids, label)`` samples into padded id lists and a label list.

    Args:
        batch: sequence of samples whose first element is a list of token ids
            and whose second element is the label (as returned by ``myDataset``).
        max_len: length to right-pad to; defaults to the notebook-global ``max_length``.
        pad_id: padding id; defaults to ``vocab["<pad>"]``.

    Returns:
        ``(data_ids, labels)``: a list of NEW right-padded id lists (sequences
        already at least ``max_len`` long are copied unchanged) and a list of
        the labels in batch order. The input lists are not modified.
    """
    if max_len is None:
        max_len = max_length
    if pad_id is None:
        pad_id = vocab["<pad>"]

    data_ids = []
    labels = []
    for sample in batch:
        ids, label = sample[0], sample[1]
        # Build a fresh list: appending to `ids` would grow the dataset's stored sample.
        data_ids.append(list(ids) + [pad_id] * (max_len - len(ids)))
        labels.append(label)
    return data_ids, labels

In [None]:
# Ordered (unshuffled) split: the leading ~80% of rows train, the trailing 20% test.
# NOTE(review): the name `test` is later rebound to the evaluation function in "Train utils".
train, test = train_test_split(dataset, shuffle=False, test_size=0.2)

In [None]:
# Tokenise both splits into id lists (train first, then test).
train_dataset, test_dataset = myDataset(train), myDataset(test)

In [None]:
batch_size = 2 ** 8  # 256 samples per DataLoader batch

In [None]:
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
# Evaluate on the tokenised `test_dataset`, not the raw `test` DataFrame: DataLoader indexes
# its dataset with integer positions, which a DataFrame interprets as column labels.
# No shuffling is needed for evaluation.
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)

In [None]:
# for batch in tqdm(train_dataloader):
#     embedded_tokens = batch[0]
#     labels = batch[1]

In [None]:
# Embedding matrix aligned with `vocab`: the GloVe rows, then one row per special token.
vec = np.vstack([
    torchtext.vocab.GloVe('6B', dim=50).vectors.numpy(),
    np.zeros((1, 50)),  # row for "<unk>"
    np.ones((1, 50)),   # row for "<pad>"
])

In [None]:
# Float32 copy of the embedding matrix (vec itself is float64).
embed_tensor = torch.from_numpy(vec).to(dtype=torch.float32, copy=True)

In [None]:
# NOTE(review): `embed` is not used by the model below, which builds its own embedding layer.
embed = nn.Embedding.from_pretrained(embeddings=embed_tensor, freeze=True)

In [None]:
# a = list(train_dataloader)[0]

In [None]:
# embed()

## Train utils

In [None]:
def train_epoch(trainloader,model,opt,loss_criterion):
    """Run one training epoch over ``trainloader`` and report the mean per-sample loss.

    Reads the notebook globals ``device`` (target device) and, if present,
    ``current_epoch_number`` (set by ``train``; used as the wandb step).

    Args:
        trainloader: iterable of ``(inputs, targets)`` batches. ``collate_fn``
            yields a list of padded token-id lists and a list of scalar label
            tensors; tensors are accepted as well.
        model: module mapping a (batch, seq) LongTensor to per-class scores.
        opt: optimizer over ``model``'s parameters.
        loss_criterion: callable ``(outputs, targets) -> scalar loss``; assumed
            to use mean reduction (as ``nn.CrossEntropyLoss()`` does by default).
    """
    import datetime
    import gc

    model.to(device)
    model.train()
    # Sum of per-sample losses, so dividing by `total` gives a true per-sample mean.
    train_loss = 0.0
    total = 0

    start_time = datetime.datetime.now()

    for inputs, targets in trainloader:
        # collate_fn returns Python lists; convert to tensors (no-op for tensors).
        inputs = torch.as_tensor(inputs, dtype=torch.long)
        if isinstance(targets, (list, tuple)):
            targets = torch.stack([torch.as_tensor(t) for t in targets])
        inputs, targets = inputs.to(device), targets.to(device)

        opt.zero_grad()
        outputs = model(inputs).reshape(len(inputs), -1)
        loss = loss_criterion(outputs, targets)
        loss.backward()
        opt.step()

        batch_len = targets.size(0)
        # The criterion returns a batch mean; weight it by the batch size.
        train_loss += loss.item() * batch_len
        total += batch_len

        # Release batch tensors eagerly; the flattened RNN output makes batches large.
        del inputs, targets, outputs, loss
        gc.collect()
        torch.cuda.empty_cache()

    epoch_loss = train_loss / total if total else float("nan")
    print("train")
    print("epoch: time", datetime.datetime.now() - start_time)
    print("loss: ", epoch_loss)
    wandb.log({"loss/train": epoch_loss}, step=globals().get("current_epoch_number"))
        
        


def test(testloader,model,loss_criterion):
    """Evaluate ``model`` on ``testloader``; print and log mean loss and accuracy.

    Reads the notebook globals ``device`` and, if present,
    ``current_epoch_number`` (used as the wandb step so test metrics line up
    with the training metrics of the same epoch).

    Args:
        testloader: iterable of ``(inputs, targets)`` batches. ``collate_fn``
            yields Python lists, which are converted to tensors here (the same
            way as in ``train_epoch``).
        model: module mapping a (batch, seq) LongTensor to per-class scores.
        loss_criterion: callable ``(outputs, targets) -> scalar loss``; assumed
            to use mean reduction. Outputs are passed to it exactly as in
            training (for ``CrossEntropyLoss`` an extra log-softmax would not
            change the value anyway).
    """
    model.to(device)
    model.eval()

    test_loss = 0.0
    correct = 0
    total = 0

    with torch.no_grad():
        for inputs, targets in testloader:
            # `.to(device)` needs tensors; collate_fn hands back Python lists.
            inputs = torch.as_tensor(inputs, dtype=torch.long)
            if isinstance(targets, (list, tuple)):
                targets = torch.stack([torch.as_tensor(t) for t in targets])
            inputs, targets = inputs.to(device), targets.to(device)

            outputs = model(inputs).reshape(len(inputs), -1)
            loss = loss_criterion(outputs, targets)

            batch_len = targets.size(0)
            # Weight the batch-mean loss by batch size for a per-sample average.
            test_loss += loss.item() * batch_len
            predicted = outputs.argmax(dim=1)
            total += batch_len
            correct += predicted.eq(targets).sum().item()

    mean_loss = test_loss / total if total else float("nan")
    accuracy = correct / total if total else float("nan")
    print('test')
    print("loss: ", mean_loss, "acc: ", accuracy)
    wandb.log({"loss/test": mean_loss, "acc/test": accuracy},
              step=globals().get("current_epoch_number"))

In [None]:
import datetime

def train(epoch,trainloader,testloader,model,loss,opt,scheduler=None):
    """Train ``model`` for ``epoch`` epochs, evaluating on ``testloader`` after each one.

    Args:
        epoch: number of epochs to run.
        trainloader / testloader: batch iterables passed to ``train_epoch`` / ``test``.
        model, loss, opt: model, criterion and optimizer.
        scheduler: optional LR scheduler stepped once per epoch. When omitted, a
            notebook-global named ``scheduler`` is used if one exists; otherwise
            no scheduling is done (previously an undefined global raised NameError).

    Side effect: sets the global ``current_epoch_number``, which the epoch
    functions use as the wandb logging step.
    """
    global current_epoch_number
    if scheduler is None:
        scheduler = globals().get("scheduler")

    # `inference` and `get_targets` are not defined in this notebook; run the extra
    # metric pass only when both helpers exist instead of crashing with NameError.
    inference_fn = globals().get("inference")
    get_targets_fn = globals().get("get_targets")
    run_extra_eval = inference_fn is not None and get_targets_fn is not None
    if not run_extra_eval:
        print("skipping evaluate(): `inference`/`get_targets` are not defined")

    for current_epoch_number in range(epoch):
        print("epoch", current_epoch_number)
        train_epoch(trainloader=trainloader, model=model, opt=opt, loss_criterion=loss)
        test(testloader=testloader, model=model, loss_criterion=loss)

        if run_extra_eval:
            outputs = inference_fn(model=model, testloader=testloader)
            targets = get_targets_fn(testloader=testloader)
            evaluate(targets, outputs)

        if scheduler is not None:
            scheduler.step()

In [None]:
def evaluate(y_true,y_pred):
    """Compute binary-classification metrics for ``y_pred`` against ``y_true``.

    Args:
        y_true: true labels (array/tensor supporting ``< 0`` and ``.all()``).
            If every entry is negative (presumably "labels unavailable" —
            confirm with the caller), nothing is computed.
        y_pred: predicted positive-class scores as a torch tensor (``.clone()``
            is used); thresholded at 0.5 for accuracy and F1.

    Returns:
        dict with keys "log_loss", "roc_auc", "acc", "f1", or None when
        ``y_true`` is all negative. (The original computed this dict and then
        discarded it.)

    NOTE(review): these metrics are binary-only, while ToneTagsRNN predicts 19
    classes; multiclass inputs need e.g. ``roc_auc_score(multi_class="ovr")``
    and an ``average=`` for F1.
    """
    if (y_true < 0).all():
        return None

    # Imported after the early return so that path does not require sklearn.
    from sklearn.metrics import accuracy_score, f1_score, log_loss, roc_auc_score

    y_pred_class = y_pred.clone()
    y_pred_class[y_pred_class < 0.5] = 0
    # `>=` so a score of exactly 0.5 is binarised too (it used to stay 0.5,
    # which sklearn rejects as a continuous label).
    y_pred_class[y_pred_class >= 0.5] = 1

    result = {"log_loss": log_loss(y_true, y_pred),
              "roc_auc": roc_auc_score(y_true, y_pred),
              "acc": accuracy_score(y_pred_class, y_true),
              "f1": f1_score(y_pred_class, y_true)}
    return result

## Model

In [None]:
# input_dim, hidden_dim, layer_dim, output_dim
import gc


In [None]:
configs = {
    "epoch": 20,
    "optLr": 1e-2,
    "optM": 0,
    "model": "RNN",
    "dataset": "dataset1",
    "loss": "CrossEntropy",
}
# Alternative configuration (ViT / BCELoss), kept for reference:
# configs = {"epoch":20,"opt":"Adam","betas":(0.9,0.99),"optLr":5e-4,"optM":0,"model":"ViT","dataset":"noChange","loss": "BCELoss"}

if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

class ToneTagsRNN(torch.nn.Module):
    """Bidirectional-RNN text classifier.

    Token ids -> ``embedding`` -> bidirectional ``nn.RNN`` -> all time steps
    flattened -> three linear layers (ReLU between them) -> class logits.

    Because every time step is flattened into ``fc1``, inputs must have exactly
    ``seq_len`` time steps (``collate_fn`` pads batches to ``max_length``).
    """

    def __init__(self,embedding, vocab_size=400002, hidden_dim=50, output_size=19, num_layers=2, dropout=0.4, seq_len=None):
        """
        Args:
            embedding: an ``nn.Embedding``; its ``embedding_dim`` sets the RNN input size.
            vocab_size: unused (the vocabulary is defined by ``embedding``); kept
                for backward compatibility.
            hidden_dim: RNN hidden size per direction.
            output_size: number of output classes.
            num_layers: number of stacked RNN layers.
            dropout: dropout between RNN layers (nn.RNN applies it only when
                ``num_layers > 1``).
            seq_len: fixed sequence length expected by ``fc1``; defaults to the
                notebook-global ``max_length``.
        """
        super().__init__()
        if seq_len is None:
            seq_len = max_length
        self.seq_len = seq_len
        # Two directions, each emitting `hidden_dim` features per time step.
        self.rnn_output_size = hidden_dim * seq_len * 2

        self.embedding = embedding
        self.rnn = torch.nn.RNN(
            input_size=self.embedding.embedding_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            bidirectional=True,
            dropout=dropout,
            batch_first=True,
        )
        self.fc1 = nn.Linear(self.rnn_output_size, 1024)
        self.fc2 = nn.Linear(1024, 512)
        self.fc3 = nn.Linear(512, output_size)

    def forward(self,x):
        """Map a (batch, seq_len) LongTensor of token ids to (batch, output_size) logits.

        Returns raw logits: ``nn.CrossEntropyLoss`` applies log-softmax itself,
        so returning softmax probabilities (as before) applied softmax twice and
        squashed the gradients. Use ``torch.softmax(out, dim=1)`` when
        probabilities are needed.
        """
        embedded = self.embedding(x)
        output, _ = self.rnn(embedded)
        # (batch, seq, 2*hidden) -> (batch, seq*2*hidden). Keeping the batch dimension
        # explicit makes a wrong sequence length fail in fc1 instead of silently
        # folding time steps into extra "samples".
        rnn_out = output.reshape(output.size(0), -1)
        fc1_out = F.relu(self.fc1(rnn_out))
        fc2_out = F.relu(self.fc2(fc1_out))
        return self.fc3(fc2_out)

# Frozen embedding layer over the GloVe matrix (plus the <unk>/<pad> rows).
embedding = nn.Embedding.from_pretrained(embed_tensor, freeze=True)
# vocab_size = len(vocab)
model = ToneTagsRNN(embedding=embedding)

loss = torch.nn.CrossEntropyLoss()
# Alternative optimiser (disabled):
# opt = torch.optim.SGD(params=model.parameters(),lr=configs['optLr'])
opt = torch.optim.Adam(params=model.parameters(), lr=configs['optLr'])

wandb.init(project="AML", name='RNN_init', config=configs)

train(
    epoch=configs['epoch'],
    trainloader=train_dataloader,
    testloader=test_dataloader,
    model=model,
    loss=loss,
    opt=opt,
)
