In [8]:
import os
import torch
import math
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter
from dataloader import dataset
from model import FFD, TRT

In [79]:
def overlap_coefficient(interval_A, interval_B):
    """Overlap coefficient between two 1-D intervals.

    The two arguments use different encodings:

    * ``interval_A`` unpacks to ``(start, end)``.
    * ``interval_B`` unpacks to ``(start, length)``; its end is computed as
      ``start + length``.

    The result is ``intersection_length / min(length_A, length_B)``, with the
    intersection length clamped at zero for disjoint intervals.

    Args:
        interval_A: two-element iterable of tensors, ``(start, end)``.
        interval_B: two-element iterable of tensors, ``(start, length)``.

    Returns:
        A tensor holding the coefficient, or the plain Python float ``0.1``
        when the shorter of the two lengths is exactly zero.
    """
    a_lo, a_hi = interval_A
    b_lo, b_span = interval_B
    b_hi = b_lo + b_span

    # Length of the shorter interval; this is the denominator of the coefficient.
    shorter = torch.min(a_hi - a_lo, b_hi - b_lo)

    # Degenerate case: avoid dividing by zero and hand back a fixed value.
    # NOTE(review): this is a Python float, so it carries no gradient; the
    # comparison also only works for single-element tensors. Confirm intended.
    if shorter == 0:
        return 0.1

    # Disjoint intervals yield a negative difference, which is clamped to zero.
    shared = torch.clamp(torch.min(a_hi, b_hi) - torch.max(a_lo, b_lo), min=0)
    return shared / shorter

In [80]:
def loss_fn(truth, ReP):
    """Mean of ``1 - overlap_coefficient`` over the rows of ``truth``.

    Args:
        truth: indexable with a ``.shape``; row ``i`` is passed to
            ``overlap_coefficient`` as the first interval, i.e. ``(start, end)``.
        ReP: model output. Only ``ReP[0]`` is used, and its row ``i`` is passed
            as the second interval, i.e. ``(start, length)``.

    Returns:
        A one-element tensor holding the averaged loss.
    """
    # Only the first entry of the prediction is scored.
    # NOTE(review): presumably this strips a batch dimension of size 1 - confirm.
    predicted = ReP[0]
    n_rows = truth.shape[0]

    # NOTE(review): relies on a module-level `device` defined in another cell,
    # and starts from an integer tensor that is promoted on the first addition.
    total = torch.tensor([0]).to(device)
    for row in range(n_rows):
        overlap = overlap_coefficient(truth[row], predicted[row])
        total = total + 1 - overlap

    return total / n_rows

In [97]:
def train(FFD, TRT, trainloader, epochs, optimizerF, optimizerT, loss_fn=loss_fn, lr=0.000001):
    """Train the FFD and TRT models side by side and checkpoint both.

    On every step the target intervals ``mean - 3*std .. mean + 3*std`` are built
    for both measures, both models are run on the same sentence, and each model
    is updated by its own optimizer. Per-step losses are logged to TensorBoard
    under ``./log/``; checkpoints are written to ``./checkpoints/``:

    * ``best_checkpointF<loss>.pth`` / ``best_checkpointT<loss>.pth`` whenever an
      epoch's summed loss beats the best one seen so far in this run;
    * ``final_checkpointF.pth`` / ``final_checkpointT.pth`` after the last epoch.

    Args:
        FFD: model called as ``FFD(sentence)``; scored against the FFD targets.
        TRT: model called as ``TRT(sentence)``; scored against the TRT targets.
        trainloader: sized iterable yielding
            ``(sentence, FFDAvg, FFDStd, TRTAvg, TRTStd)`` per step.
        epochs: number of passes over ``trainloader``.
        optimizerF: optimizer that updates ``FFD``.
        optimizerT: optimizer that updates ``TRT``.
        loss_fn: callable ``loss_fn(truth, prediction)`` returning a
            single-element tensor that supports ``.backward()``.
        lr: learning rate written into the first param group of both optimizers
            before every step. The default is the value that used to be
            hard-coded here.

    Returns:
        None.
    """
    checkpoint_dir = os.path.join('.', 'checkpoints')
    # torch.save does not create missing directories.
    os.makedirs(checkpoint_dir, exist_ok=True)

    # Best epoch totals so far. These live outside the epoch loop so that a
    # "best" checkpoint is only written when an epoch really improves on all
    # previous ones.
    best_lf = float('inf')
    best_lt = float('inf')

    writer = SummaryWriter("./log/")
    try:
        for epoch in range(epochs):
            # `step` keeps counting across epochs, so it can serve as the global step.
            loop = tqdm(enumerate(trainloader, start=epoch * len(trainloader)), total=len(trainloader), leave=False)
            lf = 0
            lt = 0
            for step, (sentence, FFDAvg, FFDStd, TRTAvg, TRTStd) in loop:
                # NOTE(review): taking w[0] and concatenating + transposing below
                # looks like it assumes a batch size of 1 - confirm with the dataset.
                sentence = [w[0] for w in sentence]
                # Target intervals as (start, end) pairs: mean -/+ three standard deviations.
                True_FFD = torch.cat((FFDAvg - 3 * FFDStd, FFDAvg + 3 * FFDStd)).T
                True_TRT = torch.cat((TRTAvg - 3 * TRTStd, TRTAvg + 3 * TRTStd)).T

                # The learning rate is (re)applied every step; only the first
                # param group of each optimizer is touched.
                optimizerF.param_groups[0]['lr'] = lr
                optimizerF.zero_grad()

                optimizerT.param_groups[0]['lr'] = lr
                optimizerT.zero_grad()

                FFD_I = FFD(sentence)
                TRT_I = TRT(sentence)

                lossF = loss_fn(True_FFD, FFD_I)
                lossF.backward()
                optimizerF.step()

                lossT = loss_fn(True_TRT, TRT_I)
                lossT.backward()
                optimizerT.step()

                # Detach before accumulating so the running totals do not keep
                # every step's autograd graph alive.
                lf = lf + lossF.detach()
                lt = lt + lossT.detach()

                # One scalar per loss. add_scalar takes (tag, value, global_step),
                # so the two losses must not share a single call.
                writer.add_scalar("LossF/train", lossF.item(), step)
                writer.add_scalar("LossT/train", lossT.item(), step)

                loop.set_description(f'Epoch [{epoch}/{epochs}]')
                # A single call with distinct keys; two calls with the same key
                # would show only the second loss.
                loop.set_postfix(lossF=lossF.item(), lossT=lossT.item())

            print(f'Loss for epoch {epoch} is {lf.cpu().detach().numpy()} and {lt.cpu().detach().numpy()}')
            if best_lf > lf:
                best_lf = lf
                state = dict(epoch=epoch, model=FFD.state_dict(),
                             optimizer=optimizerF.state_dict())
                torch.save(state, os.path.join(checkpoint_dir, 'best_checkpointF' + str(lf.item()) + '.pth'))
            if best_lt > lt:
                best_lt = lt
                # The TRT checkpoint must carry the TRT optimizer's state.
                state = dict(epoch=epoch, model=TRT.state_dict(),
                             optimizer=optimizerT.state_dict())
                torch.save(state, os.path.join(checkpoint_dir, 'best_checkpointT' + str(lt.item()) + '.pth'))

        print('End of the Training. Saving final checkpoints.')

        state = dict(epoch=epochs, model=FFD.state_dict(),
                     optimizer=optimizerF.state_dict())
        torch.save(state, os.path.join(checkpoint_dir, 'final_checkpointF.pth'))

        state = dict(epoch=epochs, model=TRT.state_dict(),
                     optimizer=optimizerT.state_dict())
        torch.save(state, os.path.join(checkpoint_dir, 'final_checkpointT.pth'))
    finally:
        # Flush and close the event file even if training is interrupted.
        writer.flush()
        writer.close()
                
                

In [75]:
# Use the first CUDA GPU when PyTorch can see one; otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')
device

device(type='cuda', index=0)

In [76]:
# Load the English training split and wrap it in a shuffling loader that
# yields one sample per step.
# NOTE(review): the CSV location is an absolute path on one user's machine;
# the notebook will not run elsewhere until this is made configurable.
Training_set = dataset(
    file_path='c:\\Users\\ludandan\\Desktop\\CCS3\\dataset\\Training\\train.csv',
    language='en',
)
trainingloader = DataLoader(Training_set, batch_size=1, shuffle=True)

In [89]:
# Build both models on the selected device. Each one gets its own Adam
# optimizer over only those parameters that still require gradients.
# NOTE(review): train() writes param_groups[0]['lr'] on every step, so the
# lr=0.1 given here is presumably not the rate that training actually uses.
ffd = FFD(padding=False).to(device)
optimizerF = torch.optim.Adam(
    [param for param in ffd.parameters() if param.requires_grad],
    lr=0.1,
    betas=(0.9, 0.999),
    eps=1e-08,
    weight_decay=0,
    amsgrad=False,
)

trt = TRT(padding=False).to(device)
optimizerT = torch.optim.Adam(
    [param for param in trt.parameters() if param.requires_grad],
    lr=0.1,
    betas=(0.9, 0.999),
    eps=1e-08,
    weight_decay=0,
    amsgrad=False,
)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.decoder.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.seq_relationship.weight', 'cls.predictions.trans

In [98]:
# Run configuration.
# NOTE(review): only `epochs` is handed to train() below; `lr` and
# `batch_size` are assigned here but not passed to it.
epochs, lr, batch_size = 100, 0.1, 1

train(
    FFD=ffd,
    TRT=trt,
    trainloader=trainingloader,
    epochs=epochs,
    optimizerF=optimizerF,
    optimizerT=optimizerT,
)

                                                                                   

Loss for epoch 0 is [83.72135406] and [77.10205421]


                                                                                       

Loss for epoch 1 is [78.30500134] and [72.67843645]


                                                                                   

Loss for epoch 2 is [74.79208116] and [67.88529536]


                                                                                       

Loss for epoch 3 is [71.78012498] and [63.16149182]


                                                                                   

Loss for epoch 4 is [69.05759813] and [57.35326546]


                                                                                   

Loss for epoch 5 is [66.21079911] and [49.94028188]


                                                                                   

Loss for epoch 6 is [63.40411284] and [42.72734683]


Epoch [7/100]:   3%|▎         | 21/626 [00:07<03:56,  2.56it/s, loss=[0.09398893]]