In [1]:
import os.path as osp
from tqdm.auto import tqdm
import numpy as np
import wandb

import torch
from PIL import Image
import torchvision.transforms as T
from torchvision.utils import save_image

import torch.nn.functional as F
# from ogb.nodeproppred import Evaluator, PygNodePropPredDataset
from torch.nn import LayerNorm, Linear, ReLU, Tanh

import torch_geometric.transforms as T
from torch_geometric.datasets import Planetoid
from torch_geometric.nn import GCNConv, GatedGraphConv, GCN
from torch_geometric.nn import DeepGCNLayer, GENConv, MessageNorm
from torch_geometric.utils import negative_sampling, to_dense_adj
from torch_geometric.loader import DataLoader

from dataset_processing import RNADataset


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
dataset = RNADataset(root="./data/")
dataset = dataset.shuffle()
train_data, val_data, test_data = dataset[0:655], dataset[655:873], dataset[873:]
bs = 4

train_dataloader = DataLoader(train_data, batch_size=bs, shuffle=False)
val_dataloader = DataLoader(val_data, batch_size=1, shuffle=False)


In [3]:
def del_nucl_conn(y):
#     main_diag = torch.zeros([1, y.size()[0] // bs])
#     diag1 = torch.ones([1, (y.size()[0] // bs) - 1])
    
    main_diag = torch.zeros([1, y.size()[0]])
    diag1 = torch.ones([1, y.size()[0] - 1])
#     batch_diags = []
#     for i in range(bs):
#         batch_diags.append((torch.diag_embed(main_diag).to(device) +
#             torch.diag_embed(diag1, offset=1).to(device) + 
#             torch.diag_embed(diag1, offset=-1).to(device) - 1)[0])

#     return y * torch.cat(batch_diags, 0).to(device) * (-1)
    return y * (torch.diag_embed(main_diag).to(device) + 
                    torch.diag_embed(diag1, offset=1).to(device) + 
                torch.diag_embed(diag1, offset=-1).to(device) - 1)[0] * (-1)


def del_main_diag(y):
#     main_diag = torch.ones([1, y.size()[0] // bs])
    main_diag = torch.ones([1, y.size()[0]])

#     batch_diags = []
#     for i in range(bs):
#         batch_diags.append((torch.diag_embed(main_diag).to(device) - 1)[0])
    
#     return y * torch.cat(batch_diags, 0).to(device) * (-1)
    return y * (torch.diag_embed(main_diag).to(device) - 1)[0] * (-1)


def adj_mat_split(y):
    y = list(torch.split(y, 196, dim=0))
#     print(y)
    for i in range(len(y)):
        y[i] = torch.split(y[i], 196, dim=1)
    
    y1 = []
    for i in range(len(y)):
        y1.append(y[i][i])
    
    return torch.cat(y1, 0).to(device)
    

def argm(y):
    max_ind = torch.argmax(y, dim=1)
    y1 = torch.zeros_like(y)
    k = 0
    for i in max_ind:
        y1[k][i] = 1.
        y1[i][k] = 1.
        k += 1
    
    return y1 * y


In [4]:
a = 0.8

def precision(y_pred, y_true, y_in):
    y_pred = y_pred + y_in
    y_pred = del_nucl_conn(y_pred)    
    y_pred[(y_pred > a)] = 1
    y_pred[(y_pred <= a)] = 0 
#     y_pred = argm(y_pred)
    y_pred = del_main_diag(y_pred)
    y_pred = adj_mat_split(y_pred)

    tp = torch.sum(y_pred * y_true)
    fp = torch.sum((1 - y_true) * y_pred)
    
    return tp / (tp + fp + epsilon)

def recall(y_pred, y_true, y_in):
    y_pred = y_pred + y_in
    y_pred = del_nucl_conn(y_pred)    
    y_pred[(y_pred > a)] = 1
    y_pred[(y_pred <= a)] = 0 
#     y_pred = argm(y_pred)
    y_pred = del_main_diag(y_pred)
    y_pred = adj_mat_split(y_pred)

    tp = torch.sum(y_pred * y_true)
    fn = torch.sum(y_true * (1 - y_pred))
    
    return tp / (tp + fn + epsilon)

def f_loss(y_pred, y_true, y_in):
    y_pred = y_pred + y_in
    y_pred = del_nucl_conn(y_pred)
#     y_pred = argm(y_pred)
    y_pred = del_main_diag(y_pred)
    y_pred = adj_mat_split(y_pred)

    tp = torch.sum(y_pred * y_true)
    fn = torch.sum(y_true * (1 - y_pred))
    fp = torch.sum((1 - y_true) * y_pred)
    precision = tp / (tp + fp + epsilon)
    recall = tp / (tp + fn + epsilon)

    k = 1
    #k1 = 1 - torch.abs(precision - recall)

#     k2 = 1 - torch.abs(K.mean(precision) - K.mean(recall))
    #calculate upgraded f1 score
    f1 = ((1 + k * k) * precision * recall) / (k * k * precision + recall + epsilon)
#     tw = K.sum(K.cast(y_true * y_pred, ’float32’), axis=[1, 2, 3])
#     fw = K.sum(K.cast((1 - y_true) * y_pred, ’float32’), axis=[1, 2, 3])
#     fb = K.sum(K.cast(y_true * (1 - y_pred), ’float32’), axis=[1, 2, 3])
    return 1 - f1


def mse_loss(y_pred, y_true):
#    y_pred = torch.tanh((y_pred*3)**5)
#     y_true = y_true - y_in
    y_pred = del_nucl_conn(y_pred)
#     y_pred[(y_pred > a)] = 1
#     y_pred[(y_pred <= a)] = 0 
#     y_pred = argm(y_pred)
    y_pred = del_main_diag(y_pred)
    y_true = del_nucl_conn(y_true)
    y_true = del_main_diag(y_true)
    
#     y_pred = adj_mat_split(y_pred)
#     return torch.sqrt(torch.mean((y_pred - y_true) ** 2))
    return (torch.sum((y_pred - y_true) ** 2))


def distance_loss(y_pred, y_true):
   # print(">>>>>>>>>>>>>>>>>>>>>>>>>>")
    #y_pred = torch.tanh((y_pred*3)**5)
    y_pred = del_nucl_conn(y_pred)
    y_pred = del_main_diag(y_pred)
    y_true = del_nucl_conn(y_true)
    y_true = del_main_diag(y_true)
    
    
    
   # print(y_pred)
    #print(y_true)
    
    #print ("---------------------------------")
    
    
    c = 0.001
    
#     y_pred[0][0] = 0.01
#     y_true[0][0] = 0.01
#     y_pred[(torch.abs(y_pred) < c)] = 0
#    c = torch.rand(y_true.size()).to(device)
#     k = 0.4
#     c[(c > k)] = 0
#     c += y_true
#     y_true += c
    #y_pred.where(y_pred > c or y_pred < -c, y_pred, 0)
    
    #print(y_pred)
    
    #print("<<<<<<<<<<<<<<<<<<<<<<<<<")
    
    #c = 0.0001
    y_true_local = y_true + c
    y_true_local = torch.triu(y_true_local, diagonal=2)
    y_true_local = y_true_local.to_sparse()
    #c_sparse = torch.triu(c, diagonal=2).to_sparse()
    y_true_local = torch.cat([y_true_local.indices(), torch.unsqueeze((y_true_local.values() - c) , 0)], dim=0)
#     y_true_local = torch.cat([y_true_local.indices(), torch.unsqueeze(y_true_local.values() * 100, 0)], dim=0)
    
    y_pred_local = torch.triu(y_pred, diagonal=2)
    y_pred_local = y_pred_local.to_sparse()
    y_pred_local = torch.cat([y_pred_local.indices(), torch.unsqueeze(y_pred_local.values(), 0)], dim=0)
    dist2 = torch.cdist(y_true_local.t(), y_pred_local.t())
    x_min,ind1 = torch.min(dist2, dim=0)
    y_min,ind1 = torch.min(dist2, dim=1)
    return torch.sum(torch.sqrt(x_min))/(x_min.size()[0]) + torch.sum(torch.sqrt(y_min))/(y_min.size()[0])
#     return torch.sum(x_min ** 2)/(x_min.size()[0]) + torch.sum(y_min ** 2)/(y_min.size()[0])

    #return torch.sum(x_min)/(x_min.size()[0])
    


In [50]:
for layer in model.layers:
    layer.dropout = 0.2


In [40]:
model.layers[0].dropout

0.2

In [7]:
class DeepGCNModel(torch.nn.Module):
    def __init__(self, num_layers, out_channels):
        super(DeepGCNModel, self).__init__()
        self.layers = torch.nn.ModuleList()
#         conv1 = GCN(dataset.num_features, out_channels, num_layers=2)
#         layer = DeepGCNLayer(conv1, block='res', dropout=0)

#         self.layers.append(layer)
        for i in range(1, num_layers + 1):

            conv = GatedGraphConv(out_channels, num_layers=6)

            layer = DeepGCNLayer(conv, block='res', dropout=0)
            self.layers.append(layer)

        self.lin = Linear(out_channels, 1)

    def forward(self, x, edge_index):
        x = self.layers[0].conv(x, edge_index)
        for layer in self.layers[1:]:
            x = layer(x, edge_index)
            
        x1, x2 = torch.split(x, out_channels // 2, dim=1)
        #!!return torch.tanh(((torch.tanh(x1 @ x2.t()))*20)**11)
        return torch.tanh(((torch.tanh(x1 @ x2.t())))**11)
#         return (x1 @ x2.t()).sigmoid()

        
#         return torch.tanh(((torch.tanh(x1 @ x2.t()))*1000))
        #return torch.tanh(torch.tanh(x @ x.t())**11)

In [49]:
num_layers = 3
# num_features = dataset.num_features
out_channels = 200


lr = 0.0000001
epochs = 175
beta = 1


# model = DeepGCNModel(num_layers, out_channels)
model = torch.load("./models1/ResGatedGCN_dist(sqrt)_diff_mult_dist_01_dropout_0_3*6_1e-05_RMSProp_200_0.8_175.pt")

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)
train_data = list(map(lambda x: x.to(device), train_data))
val_data = list(map(lambda x: x.to(device), val_data))
print(device)

# optimizer = torch.optim.Adam(model.parameters(), lr=lr)
#optimizer = torch.optim.Adagrad(model.parameters(), lr=lr)
optimizer = torch.optim.RMSprop(model.parameters(), lr=lr, momentum=0)
#optimizer = torch.optim.Adamax(model.parameters(), lr=lr)
# optimizer = torch.optim.Adadelta(model.parameters(), lr=lr)

# optimizer = torch.optim.SGD(model.parameters(), lr=0.001)

# criterion = torch.nn.BCELoss()
# criterion = torch.nn.BCEWithLogitsLoss()
# criterion = torch.nn.CrossEntropyLoss()

# def RMSELoss(y_pred, y_true):
#     return torch.sqrt(torch.mean((y_pred - y_true) ** 2))

# criterion = f_loss
# criterion = mse_loss
criterion = distance_loss

cuda


In [51]:
run_name = "ResGatedGCN_dist(sqrt)_diff_mult_dist_01_dropout(0->0.2)" \
+ str(num_layers) + "*6_" + str(lr) + "_" + "RMSProp_" + str(out_channels) + "_" + str(a)


wandb.init(
    # set the wandb project where this run will be logged
    project="secondary_structure_prediction1",
    
    # track hyperparameters and run metadata
    config={
    "learning_rate": lr,
    "architecture": "ResGatedGCN",
    "epochs": epochs,
    "optimizer": "RMSProp",
    "out_channels": out_channels,
    "loss": "dist_loss",
    "beta": beta,
    "train:val:test": "655:218:218"
    },
    name=run_name
)

epsilon = 1e-10

0,1
train_f1,▆█▇█▂▅▄▄▆▆▃▇▄▆▄▄▁▅▆▃▃▃▄▃▅▆
train_loss,██▆▆▆▅▄▄▃▄▄▃▃▂▃▄▂▃▁▁▂▂▃▂▁▃
train_precision,▃▄▄▅▂▃▄▃▅▆▁█▂▇▃▄▃▁▆▂▂▃▂▂▅▂
train_recall,▆█▇▇▂▆▄▄▅▅▅▄▅▅▄▄▁▆▅▄▃▃▅▄▄▇
val_f1,▇▆▆▅▄▄▁▇▄▇▄▅▃▆▆▄▆█▂▄▅▇▇▄▇▇
val_loss,▂▂▃▄█▃▁▇▂▂▃▄▃▂▃▁▆▂▄▃▃▃▂▂▃▂
val_precision,▄▄▄▄▃▃▁▆▅█▄▃▃▃▃▂▃▅▂▂▃▄▅▃▅▄
val_recall,█▆▆▄▄▄▁▇▂▄▃▆▃▇▇▄▆█▂▄▅█▆▄▆█

0,1
train_f1,0.37874
train_loss,0.09558
train_precision,0.26108
train_recall,0.68947
val_f1,0.36297
val_loss,0.08356
val_precision,0.24875
val_recall,0.67118


In [52]:
def train():
    for epoch in range(1, epochs + 1):
        model.train()
        train_loss = []
        train_recall = []
        train_precision = []  
        k = 0
        for g in tqdm(train_data):
            g.to(device)
            optimizer.zero_grad()

            out = model(g.x, g.edge_index)
            y_true = g.adj_mat
#             print(out.size(),y_true.size())
            loss = criterion(out, y_true.to(torch.float32) - torch.squeeze(to_dense_adj(g.edge_index)))          
#             loss = criterion(out, y_true.to(torch.float32))          

            loss.backward()
            optimizer.step()
            
#             out = out.sigmoid()
#             if k <= 2:
#                 save_image(y_true, f"./pictures/{k}/{epoch}_t.jpg")
#                 save_image(torch.squeeze(to_dense_adj(g.edge_index)), f"./pictures/{k}/{epoch}_p.jpg")
#                 save_image(out + torch.squeeze(to_dense_adj(g.edge_index)), f"./pictures/{k}/{epoch}_nn.jpg")
#             k += 1
            
            train_loss.append(loss.item())
            train_precision.append(precision(out, y_true, torch.squeeze(to_dense_adj(g.edge_index))).item())
            train_recall.append(recall(out, y_true, torch.squeeze(to_dense_adj(g.edge_index))).item())
            
        train_prec = np.mean(train_precision)
        train_rec = np.mean(train_recall)
        train_f1 = (2 * train_prec * train_rec) / (train_prec + train_rec)
        print(f'Epoch: {epoch:03d}, loss: {np.mean(train_loss)}, f1: {train_f1}, precision: {train_prec}, recall: {train_rec}')
        
        val_loss = []
        val_recall = []
        val_precision = []
        with torch.no_grad():
            for g in val_data:
                g.to(device)
                out = model(g.x, g.edge_index)
        
                y_true = g.adj_mat
                loss = criterion(out, y_true.to(torch.float32) - torch.squeeze(to_dense_adj(g.edge_index)))
#                 loss = criterion(out, y_true.to(torch.float32))          

                                
                val_loss.append(loss.item())
                
#                 out = out.sigmoid()
                
                val_precision.append(precision(out, y_true, torch.squeeze(to_dense_adj(g.edge_index))).item())
                val_recall.append(recall(out, y_true, torch.squeeze(to_dense_adj(g.edge_index))).item())
 
            
            prec = np.mean(val_precision)
            rec = np.mean(val_recall)
            f1 = (2 * prec * rec) / (prec + rec)
            print(f'val_loss: {np.mean(val_loss)}, val_f1: {f1}, val_precision: {prec}, val_recall: {rec}')
              
            wandb.log({"train_loss": np.mean(train_loss), "train_f1": train_f1, "train_precision": train_prec, 
                       "train_recall": train_rec,
                       "val_loss": np.mean(val_loss), "val_f1": f1, "val_precision": prec, "val_recall": rec})
#     wandb.finish()

In [53]:
train()
torch.save(model, "./models1/" + run_name + "_" + str(epochs) + ".pt")

100%|█████████████████████████████████████████| 655/655 [00:15<00:00, 41.78it/s]


Epoch: 001, loss: 0.09633414049753707, f1: 0.3786539215497728, precision: 0.2611049390191795, recall: 0.6887099045833559
val_loss: 0.08384594871938111, val_f1: 0.3629629895911332, val_precision: 0.24875327622637564, val_recall: 0.6710706249289556


100%|█████████████████████████████████████████| 655/655 [00:16<00:00, 40.49it/s]


Epoch: 002, loss: 0.09649215990122947, f1: 0.37866337200709, precision: 0.26108278702546395, recall: 0.6889266303932394
val_loss: 0.0839434014361649, val_f1: 0.3629487892558645, val_precision: 0.24874593960890257, val_recall: 0.6710269374584933


100%|█████████████████████████████████████████| 655/655 [00:15<00:00, 41.05it/s]


Epoch: 003, loss: 0.09629128804956456, f1: 0.37868800815823256, precision: 0.26111075026445263, recall: 0.6888950327425513
val_loss: 0.08378758321395142, val_f1: 0.3629441035302213, val_precision: 0.24874303861598082, val_recall: 0.6710160158642935


100%|█████████████████████████████████████████| 655/655 [00:15<00:00, 41.66it/s]


Epoch: 004, loss: 0.09607959880870376, f1: 0.37875428310405124, precision: 0.2611309675229642, recall: 0.6891930225468774
val_loss: 0.08371032044401222, val_f1: 0.3629739243484046, val_precision: 0.24874854334758237, val_recall: 0.6711798427848641


100%|█████████████████████████████████████████| 655/655 [00:16<00:00, 40.55it/s]


Epoch: 005, loss: 0.09616659698396467, f1: 0.3787573783725522, precision: 0.2611125054989607, recall: 0.6893421619448044
val_loss: 0.08369029960382758, val_f1: 0.36293175709260106, val_precision: 0.2487299398459289, val_recall: 0.6710269374584933


100%|█████████████████████████████████████████| 655/655 [00:16<00:00, 40.51it/s]


Epoch: 006, loss: 0.09615009222535117, f1: 0.3787107532108732, precision: 0.2611043115409503, recall: 0.6890904414062281
val_loss: 0.08373381538313558, val_f1: 0.3629737441107047, val_precision: 0.24874837405193562, val_recall: 0.6711798427848641


100%|█████████████████████████████████████████| 655/655 [00:16<00:00, 40.59it/s]


Epoch: 007, loss: 0.09614645611949776, f1: 0.3787164669212678, precision: 0.26109288191806723, recall: 0.6892079065319237
val_loss: 0.0838839025578436, val_f1: 0.362968308085711, val_precision: 0.2487537703741718, val_recall: 0.6711033902583866


100%|█████████████████████████████████████████| 655/655 [00:16<00:00, 40.57it/s]


Epoch: 008, loss: 0.09583144159774289, f1: 0.37869378108804164, precision: 0.26112836888842, recall: 0.6888106210996177
val_loss: 0.08387489062612945, val_f1: 0.3629683697511433, val_precision: 0.2487433259930769, val_recall: 0.6711798427848641


100%|█████████████████████████████████████████| 655/655 [00:16<00:00, 40.78it/s]


Epoch: 009, loss: 0.0959519486127862, f1: 0.3787156267332543, precision: 0.2611069024519156, recall: 0.6891046664642013
val_loss: 0.08360644686883197, val_f1: 0.3629546042272459, val_precision: 0.2487408977018994, val_recall: 0.6711033902583866


100%|█████████████████████████████████████████| 655/655 [00:15<00:00, 41.67it/s]


Epoch: 010, loss: 0.09580882120991253, f1: 0.3786876685884575, precision: 0.2611141934905571, recall: 0.6888688190747764
val_loss: 0.0836955375980127, val_f1: 0.36297454658366857, val_precision: 0.24874912780803551, val_recall: 0.6711798427848641


100%|█████████████████████████████████████████| 655/655 [00:15<00:00, 40.94it/s]


Epoch: 011, loss: 0.0957922079948751, f1: 0.37874838283998546, precision: 0.26112582322693506, recall: 0.6891897840809276
val_loss: 0.08372021927982091, val_f1: 0.36293958081120664, val_precision: 0.24873728926267918, val_recall: 0.6710269374584933


100%|█████████████████████████████████████████| 655/655 [00:15<00:00, 41.01it/s]


Epoch: 012, loss: 0.09593781828311564, f1: 0.37870831651943115, precision: 0.261096953356084, recall: 0.6891255597235593
val_loss: 0.08349038272195974, val_f1: 0.3629454859005032, val_precision: 0.24874283644208395, val_recall: 0.6710269374584933


100%|█████████████████████████████████████████| 655/655 [00:15<00:00, 41.18it/s]


Epoch: 013, loss: 0.09583585253554106, f1: 0.37871185060500195, precision: 0.26110763547529703, recall: 0.6890745572461426
val_loss: 0.08383572372969549, val_f1: 0.36297242209871444, val_precision: 0.24874713229989515, val_recall: 0.6711798427848641


100%|█████████████████████████████████████████| 655/655 [00:15<00:00, 40.97it/s]


Epoch: 014, loss: 0.09603370414008848, f1: 0.37871841574065424, precision: 0.2610885573835432, recall: 0.6892509516072637
val_loss: 0.08359018750033143, val_f1: 0.36295611859345067, val_precision: 0.24874232019887332, val_recall: 0.6711033902583866


100%|█████████████████████████████████████████| 655/655 [00:16<00:00, 40.18it/s]


Epoch: 015, loss: 0.09593077875773301, f1: 0.37867653223127234, precision: 0.26109304081733903, recall: 0.688942357476886
val_loss: 0.08389181329723482, val_f1: 0.3629553076077515, val_precision: 0.24874155841118425, val_recall: 0.6711033902583866


100%|█████████████████████████████████████████| 655/655 [00:16<00:00, 40.27it/s]


Epoch: 016, loss: 0.09578109357634004, f1: 0.3787112152738673, precision: 0.26112304757000837, recall: 0.6889630365348954
val_loss: 0.08365362003635354, val_f1: 0.36294142004364416, val_precision: 0.24873901701130724, val_recall: 0.6710269374584933


100%|█████████████████████████████████████████| 655/655 [00:16<00:00, 40.07it/s]


Epoch: 017, loss: 0.09576699108610513, f1: 0.37868608875748383, precision: 0.261079306647409, recall: 0.6891012882685843
val_loss: 0.0835830290213106, val_f1: 0.3629744182989828, val_precision: 0.24875951009811065, val_recall: 0.6711033902583866


100%|█████████████████████████████████████████| 655/655 [00:16<00:00, 40.23it/s]


Epoch: 018, loss: 0.09576455642405243, f1: 0.3787677494728399, precision: 0.26114656688318455, recall: 0.6891735417242268
val_loss: 0.08371773112411181, val_f1: 0.3629651696641064, val_precision: 0.24875082227683395, val_recall: 0.6711033902583866


100%|█████████████████████████████████████████| 655/655 [00:16<00:00, 40.60it/s]


Epoch: 019, loss: 0.09568900454661432, f1: 0.3785739779956454, precision: 0.2610752864065402, recall: 0.6883873367866942
val_loss: 0.0836280309431942, val_f1: 0.36296836795437376, val_precision: 0.2487538266123845, val_recall: 0.6711033902583866


100%|█████████████████████████████████████████| 655/655 [00:16<00:00, 40.36it/s]


Epoch: 020, loss: 0.09570263281663638, f1: 0.3786840288209342, precision: 0.2611276429274496, recall: 0.6887511463167558
val_loss: 0.08391098271214321, val_f1: 0.3629594221196915, val_precision: 0.24874542333151495, val_recall: 0.6711033902583866


100%|█████████████████████████████████████████| 655/655 [00:16<00:00, 40.39it/s]


Epoch: 021, loss: 0.09605778649819034, f1: 0.37877906824806307, precision: 0.2611220666413544, recall: 0.6894192190798184
val_loss: 0.08368490777294081, val_f1: 0.3629679614053269, val_precision: 0.24874294244197256, val_recall: 0.6711798427848641


100%|█████████████████████████████████████████| 655/655 [00:16<00:00, 40.32it/s]


Epoch: 022, loss: 0.09577727493055108, f1: 0.37871102545090424, precision: 0.2611115400800268, recall: 0.689041901749509
val_loss: 0.08361668564815744, val_f1: 0.3629391123703018, val_precision: 0.24873684921703482, val_recall: 0.6710269374584933


100%|█████████████████████████████████████████| 655/655 [00:16<00:00, 40.38it/s]


Epoch: 023, loss: 0.09573631664339936, f1: 0.37864625600638674, precision: 0.26110055091442497, recall: 0.688689716313859
val_loss: 0.08367450682950939, val_f1: 0.3629742523491924, val_precision: 0.24874885143592543, val_recall: 0.6711798427848641


100%|█████████████████████████████████████████| 655/655 [00:16<00:00, 40.46it/s]


Epoch: 024, loss: 0.09581223063327315, f1: 0.37872657854689185, precision: 0.2611222375952105, recall: 0.6890703803482856
val_loss: 0.08370718433124814, val_f1: 0.36294693005022777, val_precision: 0.24874419307148238, val_recall: 0.6710269374584933


100%|█████████████████████████████████████████| 655/655 [00:16<00:00, 40.21it/s]


Epoch: 025, loss: 0.09592581769408844, f1: 0.3787058206518965, precision: 0.2610829592462032, recall: 0.6892065307907476
val_loss: 0.08349782059606929, val_f1: 0.3629335903876569, val_precision: 0.24873166198953303, val_recall: 0.6710269374584933


100%|█████████████████████████████████████████| 655/655 [00:16<00:00, 40.80it/s]


Epoch: 026, loss: 0.09554983974356232, f1: 0.37869929970089183, precision: 0.2611262776952891, recall: 0.688861691235131
val_loss: 0.08379886042089227, val_f1: 0.3629431333053824, val_precision: 0.24874062643070285, val_recall: 0.6710269374584933


100%|█████████████████████████████████████████| 655/655 [00:16<00:00, 40.16it/s]


Epoch: 027, loss: 0.09544023374977116, f1: 0.37860336762344543, precision: 0.26107773562992803, recall: 0.6885646912660307
val_loss: 0.0836693260508571, val_f1: 0.36297157155623716, val_precision: 0.248746333396011, val_recall: 0.6711798427848641


100%|█████████████████████████████████████████| 655/655 [00:16<00:00, 40.38it/s]


Epoch: 028, loss: 0.09563763056035715, f1: 0.37866183238050855, precision: 0.26109837716345796, recall: 0.6888079123016987
val_loss: 0.08377063028675853, val_f1: 0.36296172150575295, val_precision: 0.24874758324799462, val_recall: 0.6711033902583866


100%|█████████████████████████████████████████| 655/655 [00:15<00:00, 42.09it/s]


Epoch: 029, loss: 0.09560485698766154, f1: 0.37867326264513873, precision: 0.26109713525262496, recall: 0.6888922085175078
val_loss: 0.08371034072781697, val_f1: 0.36296247418504723, val_precision: 0.24874829027562514, val_recall: 0.6711033902583866


100%|█████████████████████████████████████████| 655/655 [00:15<00:00, 40.96it/s]


Epoch: 030, loss: 0.09584542667166195, f1: 0.3787239389038873, precision: 0.26108323612911555, recall: 0.6893246323552751
val_loss: 0.08385812991215935, val_f1: 0.36297339571524545, val_precision: 0.24874804680739795, val_recall: 0.6711798427848641


100%|█████████████████████████████████████████| 655/655 [00:16<00:00, 40.30it/s]


Epoch: 031, loss: 0.0956208439174158, f1: 0.3786093784991278, precision: 0.2610810460650739, recall: 0.6885814283014708
val_loss: 0.08369522034051809, val_f1: 0.3629775217203757, val_precision: 0.24875192233967944, val_recall: 0.6711798427848641


100%|█████████████████████████████████████████| 655/655 [00:16<00:00, 40.80it/s]


Epoch: 032, loss: 0.09573921990360468, f1: 0.3787226226428873, precision: 0.2611363643794569, recall: 0.688945842403492
val_loss: 0.08389209869308745, val_f1: 0.3629540015211826, val_precision: 0.24874033156030495, val_recall: 0.6711033902583866


100%|█████████████████████████████████████████| 655/655 [00:16<00:00, 39.34it/s]


Epoch: 033, loss: 0.09568954080855119, f1: 0.37875077902219245, precision: 0.2611337415077759, recall: 0.6891504979770602
val_loss: 0.08359320904557943, val_f1: 0.36297534601710163, val_precision: 0.24874987871035797, val_recall: 0.6711798427848641


100%|█████████████████████████████████████████| 655/655 [00:16<00:00, 39.67it/s]


Epoch: 034, loss: 0.09562204805352079, f1: 0.3786249329402602, precision: 0.2610721033749007, recall: 0.6887465704715889
val_loss: 0.08361464867628685, val_f1: 0.36297659941592114, val_precision: 0.2487510560217117, val_recall: 0.6711798427848641


100%|█████████████████████████████████████████| 655/655 [00:16<00:00, 39.80it/s]


Epoch: 035, loss: 0.09563230364405699, f1: 0.3786805917821402, precision: 0.2611145216516639, recall: 0.6888197018786241
val_loss: 0.08379442365149362, val_f1: 0.3629677212188695, val_precision: 0.24874271683975918, val_recall: 0.6711798427848641


100%|█████████████████████████████████████████| 655/655 [00:16<00:00, 40.12it/s]


Epoch: 036, loss: 0.09547893303865922, f1: 0.3787183058132012, precision: 0.2611370936718606, recall: 0.6889121967417593
val_loss: 0.08373920784025256, val_f1: 0.3629863821839852, val_precision: 0.24876024503124142, val_recall: 0.6711798427848641


100%|█████████████████████████████████████████| 655/655 [00:16<00:00, 40.16it/s]


Epoch: 037, loss: 0.09572350312706851, f1: 0.3786930625411489, precision: 0.26108362551579495, recall: 0.6891173859645847
val_loss: 0.08347758601931844, val_f1: 0.36292703902807305, val_precision: 0.2487360134532829, val_recall: 0.6709504849320158


100%|█████████████████████████████████████████| 655/655 [00:16<00:00, 40.24it/s]


Epoch: 038, loss: 0.09575032187654203, f1: 0.37869855030593136, precision: 0.26109498288937427, recall: 0.689074612488501
val_loss: 0.08385665388777852, val_f1: 0.36294989074769296, val_precision: 0.24873647019437967, val_recall: 0.6711033902583866


100%|█████████████████████████████████████████| 655/655 [00:16<00:00, 40.32it/s]


Epoch: 039, loss: 0.09547612325622494, f1: 0.3786894157234017, precision: 0.2611076933566396, recall: 0.688925628972645
val_loss: 0.08352612868691828, val_f1: 0.3629823819465248, val_precision: 0.2487564875633208, val_recall: 0.6711798427848641


100%|█████████████████████████████████████████| 655/655 [00:16<00:00, 40.89it/s]


val_loss: 0.08369231801083184, val_f1: 0.36297770195668516, val_precision: 0.2487520916353262, val_recall: 0.6711798427848641


100%|█████████████████████████████████████████| 655/655 [00:16<00:00, 40.31it/s]


Epoch: 041, loss: 0.09548125305561392, f1: 0.3787181593427682, precision: 0.26116326911701954, recall: 0.6887291210412069
val_loss: 0.083739779720743, val_f1: 0.3629431933812641, val_precision: 0.2487406828654332, val_recall: 0.6710269374584933


100%|█████████████████████████████████████████| 655/655 [00:16<00:00, 40.55it/s]


Epoch: 042, loss: 0.095411847302875, f1: 0.37869230020415223, precision: 0.26112269633485163, recall: 0.6888402946407103
val_loss: 0.08372340859939528, val_f1: 0.36295556990162003, val_precision: 0.2487418047929986, val_recall: 0.6711033902583866


100%|█████████████████████████████████████████| 655/655 [00:16<00:00, 40.71it/s]


Epoch: 043, loss: 0.09558445558037239, f1: 0.3787485809656111, precision: 0.2611294394003526, recall: 0.6891659073474753
val_loss: 0.08370004256222616, val_f1: 0.3629620371224464, val_precision: 0.24874612920567257, val_recall: 0.6711161322549942


 74%|██████████████████████████████▏          | 483/655 [00:11<00:04, 40.79it/s]


KeyboardInterrupt: 

In [None]:
wandb.finish()