In [None]:
import json
from pathlib import Path
from dataset import *
import numpy as np
import pandas as pd
from torch.utils.data import DataLoader, Dataset
from model import *
from tqdm import tqdm
import sys
import os
from metrics import *
import torch
import argparse
import copy


import pynvml
pynvml.nvmlInit()
handle = pynvml.nvmlDeviceGetHandleByIndex(0)   #gpu_id
meminfo = pynvml.nvmlDeviceGetMemoryInfo(handle)


In [None]:
if not os.path.exists("./outputs"):
    os.mkdir("./outputs")
data_dir = Path('..//input/')

train_mark_path = './data/train_mark.csv'
train_features_path = "./data/train_fts.json"
val_mark_path = './data/val_mark.csv'
val_features_path = './data/val_fts.json'
val_path = "./data/val.csv"
model_name_or_path ="microsoft/codebert-base"

md_max_len = 512
total_max_len = 48
batch_size= 1
accumulation_steps=1
epochs=1
n_workers=4

In [None]:
train_df_mark = pd.read_csv(train_mark_path).drop("parent_id", axis=1).dropna().reset_index(drop=True)
train_fts = json.load(open(train_features_path))
val_df_mark = pd.read_csv(val_mark_path).drop("parent_id", axis=1).dropna().reset_index(drop=True)
val_fts = json.load(open(val_features_path))
val_df = pd.read_csv(val_path)

order_df = pd.read_csv("../input/train_orders.csv").set_index("id")
#df_orders = pd.read_csv(
#    data_dir / 'train_orders.csv',
#    index_col='id',
#    squeeze=True,
#).str.split()

order = pd.read_csv(
    data_dir / 'train_orders.csv'
)
train_path = './data/train.csv'
val_path = './data/val.csv'
val_df = pd.read_csv(val_path)
train_df = pd.read_csv(train_path)

full_df = pd.concat([train_df, val_df])

In [None]:
# tmp = train_df_mark[train_df_mark['cell_type']=='markdown'].groupby("id").count()['cell_id']
# tmp.apply(lambda x: x**2).sum()/len(tmp) # = 539
# tmp.mean() # = 15.56
# train_df.head()

In [None]:
def get_dict(df):
    md_dict = dict(zip(df["cell_id"].values, df['source'].values))
    return md_dict

md_dict = get_dict(full_df[full_df['cell_type']=='markdown'])
cd_dict = get_dict(full_df[full_df['cell_type']=='code'])

In [None]:
#len(md_dict), len(cd_dict)
#len(train_df_mark), len(train_df), len(val_df_mark), len(val_df)
#order.head()

In [None]:
%time
x = 1<<32
seed = 42
def transform(row):
    
    cell_ids = list(row["cell_order"].split())
    md_ids = []
    cell_shuffle = []
    md_mask = []
    for cell_id in cell_ids:
        if cell_id in md_dict:
            cell_shuffle.append(0)
            md_ids.append(cell_id)
            md_mask.append(1)
        else:
            cell_shuffle.append(cell_id)
            md_mask.append(0)
    length = len(md_ids)
    _hash = hash(row["cell_order"])*seed % x
    permutation = np.arange(length)
    #permutation = np.random.RandomState(seed=_hash).permutation(length)
    i = 0
    for j in range(len(cell_shuffle)):
        if cell_shuffle[j] == 0:
            cell_shuffle[j] = cell_ids[permutation[i]]
            i+=1
    
    return pd.Series([row.id, cell_shuffle, permutation, md_mask, len(md_mask)], index=['id', 'cell_shuffle', 'permutation', 'md_mask', 'md_len'])

order = order.apply(transform, axis=1)
order.head()

In [None]:
tr_ids =   train_df["id"].unique()
train_order = order[order["id"].isin(tr_ids)]
val_order  = order[~(order["id"].isin(tr_ids))]
train_df = train_df.set_index("cell_id", drop=True)
val_df = val_df.set_index("cell_id", drop=True)

In [None]:
# train_order.head()
# val_order.head()
# train_df.head()

## test

In [None]:
#order.head()
#row = order.iloc[1]
#for cell_id in row.cell_shuffle:
#    print(cell_id)
#    print(train_df.loc[cell_id].source)
#    print("\n"*5)

In [None]:
#tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
#row = train_order.iloc[1]
#cells = row.cell_shuffle
#[train_df.loc[cell_id].source for cell_id in cells]
#inputs = tokenizer.batch_encode_plus(
#            [train_df.loc[cell_id].source for cell_id in cells],
#            add_special_tokens=True,
#            max_length=total_max_len,
#            padding="max_length",
#            # return_token_type_ids=True,
#            truncation=True
#        )
#ids = torch.LongTensor(inputs['input_ids'])
#mask = torch.LongTensor(inputs['attention_mask'])
#md_mask = torch.LongTensor(row.md_mask)
#permutation = torch.LongTensor(row.permutation)

        
#print("ids:", ids.size())
#print("mask:", mask.size())
#print("md_mask:", md_mask.size())

In [None]:
class MarkdownDataset(Dataset):
    def __init__(self, order_df, df, model_name_or_path, total_max_len, md_max_len, fts):
        super().__init__()
        self.order_df = order_df # .sort_values(by=['md_len','id'], ascending=False)
        self.df = df
        self.md_max_len = md_max_len
        self.total_max_len = total_max_len  # maxlen allowed by model config
        self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
        self.fts = fts

    def __getitem__(self, index):
        row = self.order_df.iloc[index]
        cells = row.cell_shuffle

        # print("index: "+ str(index))

        inputs = self.tokenizer.batch_encode_plus(
            [str(self.df.loc[cell_id].source) for cell_id in cells],
            # None,
            add_special_tokens=True,
            max_length=self.total_max_len,
            padding="max_length",
            # return_token_type_ids=True,
            truncation=True
        )

        ids = torch.LongTensor(inputs['input_ids'])
        mask = torch.LongTensor(inputs['attention_mask'])
        md_mask = torch.LongTensor(row.md_mask)
        permutation = torch.LongTensor(row.permutation)
    
        length = len(md_mask)
        
        #print("ids:", ids.size())
        #print("mask:", mask.size())
        #print("md_mask:", md_mask.size())
        #return torch.FloatTensor([length])
        return ids, mask, permutation, md_mask, length

    def __len__(self):
        return self.order_df.shape[0]

In [None]:
train_ds = MarkdownDataset(train_order, train_df, model_name_or_path=model_name_or_path, md_max_len=md_max_len,
                           total_max_len=total_max_len, fts=train_fts)
val_ds = MarkdownDataset(val_order, val_df, model_name_or_path=model_name_or_path, md_max_len=md_max_len,
                         total_max_len=total_max_len, fts=val_fts)
# 每次shuffle均不同
train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=n_workers,
                          pin_memory=False, drop_last=True)
val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False, num_workers=n_workers,
                        pin_memory=False, drop_last=False)

In [26]:
# model.cpu()
HIDDEN_SIZE = 768
class BersonModel(nn.Module):
    def __init__(self, model_path):
        super(BersonModel, self).__init__()
        self.model = AutoModel.from_pretrained(model_path)
        
        self.dropout = nn.Dropout(0.1)
        self.decoder = nn.LSTM(input_size=HIDDEN_SIZE, hidden_size=HIDDEN_SIZE, num_layers=1, batch_first=True)
        self.top = nn.Linear(HIDDEN_SIZE+2, 1)
        self.softmax = nn.Softmax(dim=0)
        self.layernorm = nn.LayerNorm(2)
        self.loss = nn.BCELoss()


    def forward(self, ids, mask, order, md_mask, length):
        """
        :param ids:
        :param mask:
        :param orders:
        :param type: 0 for code, 1 for md
        :return:
        """
        ids = ids.squeeze(0) # [:8,:]
        mask = mask.squeeze(0) # [:8,:]
        order = order.squeeze(0)
        md_mask = md_mask.squeeze(0)
        length = length.squeeze(0)
        # print("ids:", ids.size())
        #print("md_mask:", md_mask.size())
        #rint("length[:", length[0])
        
        
        #print(torch.cuda.memory_allocated(device=0) / (1024 * 1024))#0表示显卡号
        
        """
        print("size:",ids.size()[0])
        def get_batch():
            q = 16
            l = []
            while q < length+16:
                q_ids = ids[q-16:q, :]
                q_mask = mask[q-16:q, :]
                t = self.model(q_ids, q_mask)[1]
                q += 16
                l.append(t)
            return torch.cat(l, 0)
        """
        
        def get_batch():
            return self.model(ids, mask)[0][:,[0],:]
        
        if length > 64:
            with torch.no_grad():
                x = get_batch()
        else:
            x = get_batch()
        del ids, mask
        
        
        x = torch.swapaxes(x, 0, 1)
        # print("x:", x.size())
        
        # hn -> (Layer, batch, hiddenstate)
        cn = torch.sum(x, 1).cuda().unsqueeze(1)
        #print("cn:", cn.size())
        hn = torch.zeros_like(cn).cuda()
        hcn = (hn, cn)
        
        md_len = len(order)
        # print("hn:", hn.size())
        # x = self.model(ids, mask)[1]
        # print(torch.cuda.memory_allocated(device=0) / (1024 * 1024))#0表示显卡号
        
        md_pos = [idx for idx, i in enumerate(md_mask) if i==1]
        # print("md_pos:", md_pos)
        # print("md_mask:", md_mask)
        # print("order:", order)
        md_pool = torch.swapaxes(x[:, md_pos, :], 0, 1)
        # print("md_num:",md_len)
        # print("md_pool:", md_pool.size())
        loss = 0
        j, idx = 0, 0
        
        # step = 1./md_len
        # gt_loc = torch.arange(0, 1+step, step)[:md_len].cuda()
        # assert len(gt_loc) == md_len
        
        for i in range(length):
            if md_mask[i] == 1:
                idx = order[j]  
                # construct hcn for batch
                hn, cn = hcn
                hn = torch.tile(hn, (1, md_len, 1))
                cn = torch.tile(cn, (1, md_len, 1))
                t_hcn = (hn, cn)
                
                t_x, t_hcn = self.decoder(md_pool, t_hcn)
                # t_x: (n, 1, h)
                feats = torch.tensor([i*1.0/length, j*1.0/md_len], dtype=t_x.dtype).unsqueeze(0).unsqueeze(1).cuda()
                feats = torch.tile(feats, (t_x.size()[0], t_x.size()[1],1))
                t_x = torch.cat((t_x,feats), 2)
                #print("t_x:", t_x.size(), "  idx:", idx)
                # torch.cat(torch.tensor(i*1.0/length, dtype=t_x.dtype), 2)
                t_x = self.top(self.dropout(t_x))
                # print("t_x:", t_x.size(), "  idx:", idx)
                output = self.softmax(t_x).squeeze(2).squeeze(1)
                
                gt = torch.zeros_like(output).cuda()
                gt[idx] = 1.
                # print("output:",output)
                # print("gt:", gt)
                loss += self.loss(output, gt)
                # print(torch.square(gt_loc - torch.tensor(j*1.0/md_len, dtype=float).cuda()))
                # loss += torch.sum(torch.square(gt_loc - torch.tensor(j*1.0/md_len, dtype=float).cuda()) * output)
                j += 1
            # print("test:", x[:, [i], :].size())
            output, hcn = self.decoder(x[:, [i], :], hcn)
        
                
        
        ## concat
        loss /= md_len
        # print("loss:", loss)

        # lstm
        #hn = 
        #hcn = 
        #for t in range():
        #    output, hcn = self.decoder(x, hcn) # hcn = (hidden_state, cell_state)
        # loss
        # loss = 1

        ## sentence order

        ## coherence
        return loss
    
    def beam_search(self, ids, mask, order, md_mask, length, k=5):
        """
        :param ids:
        :param mask:
        :param orders:
        :param type: 0 for code, 1 for md
        :return:
        """
        ids = ids.squeeze(0) # [:8,:]
        mask = mask.squeeze(0) # [:8,:]
        order = order.squeeze(0)
        md_mask = md_mask.squeeze(0)
        length = length.squeeze(0)
        
        # print("ids:", ids.size())
        #print("md_mask:", md_mask.size())
        #rint("length[:", length[0])


        #print(torch.cuda.memory_allocated(device=0) / (1024 * 1024))#0表示显卡号

        """
        print("size:",ids.size()[0])
        def get_batch():
          q = 16
          l = []
          while q < length+16:
              q_ids = ids[q-16:q, :]
              q_mask = mask[q-16:q, :]
              t = self.model(q_ids, q_mask)[1]
              q += 16
              l.append(t)
          return torch.cat(l, 0)
        """

        def get_batch():
            return self.model(ids, mask)[0][:,[0],:]
        
        if length > 64:
            with torch.no_grad():
                x = get_batch()
        else:
            x = get_batch()
        del ids, mask
        
        
        x = torch.swapaxes(x, 0, 1)
        # print("x:", x.size())
        
        # hn -> (Layer, batch, hiddenstate)
        cn = torch.sum(x, 1).cuda().unsqueeze(1)
        #print("cn:", cn.size())
        hn = torch.zeros_like(cn).cuda()
        hcn = (hn, cn)

        md_len = len(order)
        # print("hn:", hn.size())
        # x = self.model(ids, mask)[1]
        # print(torch.cuda.memory_allocated(device=0) / (1024 * 1024))#0表示显卡号

        md_pos = [idx for idx, i in enumerate(md_mask) if i==1]
        # (md_len, 1, h)
        md_pool = torch.swapaxes(x[:, md_pos, :], 0, 1)
        loss = 0
        j, idx = 0, 0
        # value, hcn, used
        queue = [(0, hcn, torch.FloatTensor([1]*md_len), [])]
        nxt_queue = []

        for i in range(length):
            if md_mask[i] == 1:
                for score, hcn, mask, res in queue:
                    # construct hcn for batch
                    hcn = (torch.tile(hcn[0], (1, md_len, 1)), torch.tile(hcn[1], (1, md_len, 1)))
                    # masked select for dim 1 
                    # hcn_mask = mask.unsqueeze(0).unsqueeze(2)
                    # hcn_mask = torch.tile(hcn_mask, (hcn[0].size()[0], 1, hcn[0].size()[2]))
                    # print("masked_selected:", torch.masked_select(hcn[0], hcn_mask).size())
                    # hcn = (torch.masked_select(hcn[0], hcn_mask).reshape(hcn[0].size()[0],-1,hcn[0].size()[2]),
                    #        torch.masked_select(hcn[1], hcn_mask).reshape(hcn[0].size()[0],-1,hcn[0].size()[2]))
                    # print("md_pool size：", md_pool.size())
                    
                    # md_pool_mask = mask.unsqueeze(1).unsqueeze(2)
                    # md_pool_mask = torch.tile(md_pool_mask, (1,1,md_pool.size()[2]))
                    # masked select for dim 0
                    # t_x = torch.masked_select(md_pool, md_pool_mask).reshape(-1,md_pool.size()[1],md_pool.size()[2])
                    # print("t_x:", t_x.size(), "  idx:", idx)
                    t_x, hcn = self.decoder(md_pool, hcn)
                    # x: (n, 1, h)
                    feats = torch.tensor([i*1.0/length, j*1.0/md_len], dtype=t_x.dtype).unsqueeze(0).unsqueeze(1).cuda()
                    feats = torch.tile(feats, (t_x.size()[0], t_x.size()[1],1))
                    t_x = torch.cat((t_x,feats), 2)
                    
                    t_x = self.top(self.dropout(t_x)).squeeze(2).squeeze(1) * mask.cuda()
                    # print(t_x)
                    t_x = torch.where(t_x != 0.0, t_x, torch.tensor(-1e3, dtype=t_x.dtype).cuda())
                    output = self.softmax(t_x)
                    
                    # print("output:",output.size(), output)
                    values, indexes = torch.topk(output, min(k, len(mask)))
                    
                    #print(values)

                    for value, index in zip(values, indexes):
                        # print(value,index, type(value), type(index))
                        # index = index.detach().cpu()

                        new_mask = torch.clone(mask)
                        new_mask[index] = 0
                        new_res = copy.deepcopy(res)
                        new_res.append(index)
                        # print("nxt_queue:", index, new_mask)
                        # print("hcn size", hcn[0].size())
                        nxt_queue.append((value+score, (hcn[0][:,[index],:], hcn[1][:,[index],:]), new_mask, new_res))
                j += 1
            else:
                for score, hcn, mask, res in queue:
                    # print("x size:", x.size(), x[:, [i], :].size())
                    t_x, hcn = self.decoder(x[:, [i], :], hcn)
                    nxt_queue.append((score, hcn, mask, res))    
            nxt_queue.sort(key=lambda x:-x[0])
            nxt_queue = nxt_queue[:k]
            queue = nxt_queue
            nxt_queue = []

        return queue

In [27]:
import time
def read_data(d):
    return tuple([data.cuda() for data in d])

model = BersonModel(model_name_or_path)
model = model.cuda()
model.load_state_dict(torch.load("./outputs/model_5000.bin"))

<All keys matched successfully>

In [28]:
# param_optimizer = list(model.named_parameters())
# print([n for n,p in param_optimizer])

In [29]:
np.random.seed(0)
# Creating optimizer and lr schedulers
param_optimizer = list(model.named_parameters())
no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
add = ['decoder, top']
optimizer_grouped_parameters = [
    {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay+add)],'weight_decay': 0.01},
    {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0},
    {'params': [p for n, p in param_optimizer if any(nd in n for nd in add)], "lr":3e-5 , 'weight_decay': 0.01},
]

num_train_optimization_steps = int(epochs * len(train_loader) / accumulation_steps)
print("steps:",num_train_optimization_steps, "acc_step:", accumulation_steps)
optimizer = AdamW(optimizer_grouped_parameters, lr=3e-5,
                  correct_bias=False)  # To reproduce BertAdam specific behavior set correct_bias=False
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=0.05 * num_train_optimization_steps,
                                            num_training_steps=num_train_optimization_steps)  # PyTorch scheduler

criterion = torch.nn.L1Loss()
scaler = torch.cuda.amp.GradScaler()


steps: 125292 acc_step: 1


In [30]:
model.train()
f = open("./outputs/log.dat", "w+")
tbar = tqdm(train_loader,file=f)
loss_list = []
preds = []
labels = []

## next

In [None]:
#if e <= 2:
    #y_val, y_pred = validate(model, val_loader)
    #val_df["pred"] = val_df.groupby(["id", "cell_type"])["rank"].rank(pct=True)
    #val_df.loc[val_df["cell_type"] == "markdown", "pred"] = y_pred
    #y_dummy = val_df.sort_values("pred").groupby('id')['cell_id'].apply(list)
    #print("Preds score", kendall_tau(df_orders.loc[y_dummy.index], y_dummy))
    #continue
#print("epoch:", e)

avg_loss = 0
for idx, data in enumerate(tbar):
    # print(type(data), type(data[0]), data[0].size())
    # print("data[3].size()[1]: ",data[3].size()[1])
    
    data = read_data(data)
    
    #with torch.cuda.amp.autocast():
    loss = model(*data)
        # loss = criterion(pred, target)
    scaler.scale(loss).backward()
    nn.utils.clip_grad_norm_(model.parameters(), max_norm=1)
    if idx % accumulation_steps == 0 or idx == len(tbar) - 1:
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()
        scheduler.step()
        
    loss_list.append(loss.detach().cpu().item())
    if idx == 0:
        continue

    if idx % 100 == 0:
        avg_loss = np.round(np.mean(loss_list[-1000:]), 4)
        # print("memory allocated:",torch.cuda.memory_allocated(device=0) / (1024 * 1024))
        print("avg_loss:", avg_loss)
    if idx % 5000 == 0:
        torch.save(model.state_dict(), f"./outputs/model_{idx}.bin")

    
    # preds.append(pred.detach().cpu().numpy().ravel())
    # labels.appendabs(target.detach().cpu().numpy().ravel())



    

In [31]:
def validate(model, val_loader):
    model.eval()

    tbar = tqdm(val_loader, file=sys.stdout)

    preds = []
    labels = []

    with torch.no_grad():
        for idx, data in enumerate(tbar):
            inputs = read_data(data)

            with torch.cuda.amp.autocast():
                pred = model.beam_search(*inputs)
                #print(len(pred[0][-2]), pred[0][-1])
                res = [t.detach().cpu().numpy().ravel()[0] for t in pred[0][-1]]
            preds.append(res)
            labels.append(np.arange(len(res)).ravel())
            if idx > 100:
                break
    return labels, preds
a, b = validate(model, val_loader)


  1%|██▋                                                                                                                                                                                                                                                                                                                                                                                 | 101/13964 [00:17<39:09,  5.90it/s]


In [32]:
def getInvCount(arr):
    n = len(arr)
    inv_count = 0
    for i in range(n):
        for j in range(i + 1, n):
            if (arr[i] > arr[j]):
                inv_count += 1
  
    return inv_count

def kendall_tau(preds):
    total_inversions, total_2max = 0, 0 
    for pred in preds:
        n = len(pred)
        print(pred,count_inversions(pred), n * (n - 1))
        total_inversions += count_inversions(pred)
        total_2max += n * (n - 1) 
    return 1 - 4 * total_inversions / total_2max

kendall_tau(b)

[1, 0, 10, 3, 4, 8, 2, 7, 5, 9, 11, 6] 19 132
[0, 12, 13, 2, 20, 5, 18, 4, 8, 7, 15, 11, 14, 6, 1, 3, 19, 16, 17, 9, 10] 91 420
[0, 1, 9, 3, 11, 10, 12, 6, 4, 7, 2, 8, 5] 34 156
[8, 1, 5, 2, 4, 3, 7, 0, 9, 6] 20 90
[8, 20, 0, 3, 5, 12, 17, 19, 14, 6, 1, 10, 4, 18, 7, 2, 13, 15, 9, 11, 16] 94 420
[0] 0 0
[5, 1, 6, 2, 4, 3, 0] 14 42
[0] 0 0
[1, 7, 2, 0, 3, 5, 6, 4, 9, 8] 11 90
[15, 13, 1, 8, 10, 7, 11, 5, 9, 14, 6, 3, 0, 12, 2, 4] 80 240
[1, 6, 0, 4, 5, 2, 3] 10 42
[0, 11, 10, 6, 13, 8, 1, 3, 7, 9, 2, 5, 4, 12, 14] 47 210
[6, 18, 3, 12, 15, 9, 2, 1, 10, 8, 7, 5, 13, 11, 14, 4, 16, 17, 0, 19] 82 380
[9, 7, 3, 5, 10, 4, 8, 6, 0, 2, 1] 40 110
[5, 7, 0, 6, 2, 1, 3, 4] 16 56
[3, 1, 2, 4, 6, 0, 5] 8 42
[10, 7, 0, 8, 2, 3, 6, 9, 5, 1, 11, 4] 34 132
[29, 26, 24, 18, 27, 3, 16, 12, 23, 11, 20, 21, 8, 6, 19, 2, 7, 13, 14, 15, 30, 22, 25, 1, 4, 17, 5, 10, 28, 9, 0] 293 930
[6, 28, 30, 3, 1, 16, 19, 11, 15, 14, 10, 32, 24, 22, 26, 31, 21, 20, 8, 18, 33, 27, 4, 5, 0, 9, 29, 12, 25, 7, 2, 23, 13, 17] 

0.05785017957927141

In [None]:
#if e <= 2:
    #y_val, y_pred = validate(model, val_loader)
    #val_df["pred"] = val_df.groupby(["id", "cell_type"])["rank"].rank(pct=True)
    #val_df.loc[val_df["cell_type"] == "markdown", "pred"] = y_pred
    #y_dummy = val_df.sort_values("pred").groupby('id')['cell_id'].apply(list)
    #print("Preds score", kendall_tau(df_orders.loc[y_dummy.index], y_dummy))
    #continue
#print("epoch:", e)

avg_loss = 0
for idx, data in enumerate(tbar):
    # print(type(data), type(data[0]), data[0].size())
    # print("data[3].size()[1]: ",data[3].size()[1])
    
    data = read_data(data)
    
    #with torch.cuda.amp.autocast():
    loss = model(*data)
        # loss = criterion(pred, target)
    scaler.scale(loss).backward()
    nn.utils.clip_grad_norm_(model.parameters(), max_norm=1)
    if idx % accumulation_steps == 0 or idx == len(tbar) - 1:
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()
        scheduler.step()
        
    loss_list.append(loss.detach().cpu().item())
    if idx == 0:
        continue

    if idx % 100 == 0:
        avg_loss = np.round(np.mean(loss_list[-1000:]), 4)
        # print("memory allocated:",torch.cuda.memory_allocated(device=0) / (1024 * 1024))
        print("avg_loss:", avg_loss)
    if idx % 5000 == 0:
        torch.save(model.state_dict(), f"./outputs/model_{idx}.bin")

    
    # preds.append(pred.detach().cpu().numpy().ravel())
    # labels.appendabs(target.detach().cpu().numpy().ravel())



    

In [None]:
import matplotlib.pyplot as plt 
plt.title('loss')
plt.plot(loss_list, label='train')
#plt.plot(history.history['val_loss'], label='test')
plt.legend()
plt.show()



In [None]:
with torch.cuda.amp.autocast():
        pred = model(*inputs)
        loss = criterion(pred, target)
    scaler.scale(loss).backward()
    if idx % args.accumulation_steps == 0 or idx == len(tbar) - 1:
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()
        scheduler.step()

    if idx % 10000 == 0:
        torch.save(model.state_dict(), f"./outputs/model_{idx}.bin")

    loss_list.append(loss.detach().cpu().item())
    preds.append(pred.detach().cpu().numpy().ravel())
    labels.append(target.detach().cpu().numpy().ravel())

    avg_loss = np.round(np.mean(loss_list), 4)

    tbar.set_description(f"Epoch {e + 1} Loss: {avg_loss} lr: {scheduler.get_last_lr()}")



y_val, y_pred = validate(model, val_loader)
val_df["pred"] = val_df.groupby(["id", "cell_type"])["rank"].rank(pct=True)
val_df.loc[val_df["cell_type"] == "markdown", "pred"] = y_pred
y_dummy = val_df.sort_values("pred").groupby('id')['cell_id'].apply(list)
print("Preds score", kendall_tau(df_orders.loc[y_dummy.index], y_dummy))
torch.save(model.state_dict(), f"./outputs/model_epoch_{e+1}.bin")

In [None]:
for e in range(epochs):
    model.train()
    tbar = tqdm(train_loader, file=sys.stdout)
    loss_list = []
    preds = []
    labels = []

    #if e <= 2:
        #y_val, y_pred = validate(model, val_loader)
        #val_df["pred"] = val_df.groupby(["id", "cell_type"])["rank"].rank(pct=True)
        #val_df.loc[val_df["cell_type"] == "markdown", "pred"] = y_pred
        #y_dummy = val_df.sort_values("pred").groupby('id')['cell_id'].apply(list)
        #print("Preds score", kendall_tau(df_orders.loc[y_dummy.index], y_dummy))
        #continue
    #print("epoch:", e)

    for idx, data in enumerate(tbar):
        inputs = read_data(data)
        target = 1

        time.sleep(3)
        with torch.cuda.amp.autocast():
            pred = model(*inputs)
            loss = criterion(pred, target)
        scaler.scale(loss).backward()
        if idx % args.accumulation_steps == 0 or idx == len(tbar) - 1:
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()
            scheduler.step()

        if idx % 2000 == 0:
            torch.save(model.state_dict(), f"./outputs/model_{idx}.bin")

        loss_list.append(loss.detach().cpu().item())
        preds.append(pred.detach().cpu().numpy().ravel())
        labels.append(target.detach().cpu().numpy().ravel())

        avg_loss = np.round(np.mean(loss_list), 4)

        tbar.set_description(f"Epoch {e + 1} Loss: {avg_loss} lr: {scheduler.get_last_lr()}")

    y_val, y_pred = validate(model, val_loader)
    val_df["pred"] = val_df.groupby(["id", "cell_type"])["rank"].rank(pct=True)
    val_df.loc[val_df["cell_type"] == "markdown", "pred"] = y_pred
    y_dummy = val_df.sort_values("pred").groupby('id')['cell_id'].apply(list)
    print("Preds score", kendall_tau(df_orders.loc[y_dummy.index], y_dummy))
    torch.save(model.state_dict(), f"./outputs/model_epoch_{e+1}.bin")

return model, y_pred

In [None]:
# model.load_state_dict(torch.load("./outputs/model.bin"))
model, y_pred = train(model, train_loader, val_loader, epochs=epochs)


In [None]:
train_df.loc['3a16a457']

In [None]:
s = time.time()
print(torch.cuda.memory_allocated(device=0) / (1024 * 1024))
print(time.time()-s)