In [2]:
import os.path as osp
from tqdm.auto import tqdm
import numpy as np
import wandb

import torch
from sklearn.metrics import roc_auc_score

import torch.nn.functional as F
# from ogb.nodeproppred import Evaluator, PygNodePropPredDataset
from torch.nn import LayerNorm, Linear, ReLU

import torch_geometric.transforms as T
from torch_geometric.datasets import Planetoid
from torch_geometric.nn import GCNConv, GatedGraphConv
from torch_geometric.nn import DeepGCNLayer, GENConv
from torch_geometric.utils import negative_sampling
from torch_geometric.loader import DataLoader

from dataset_processing import RNADataset


In [3]:
dataset = RNADataset(root="./data/")
dataset = dataset.shuffle()
train_data, val_data, test_data = dataset[0:655], dataset[655:873], dataset[873:]
bs = 4

train_dataloader = DataLoader(train_data, batch_size=bs, shuffle=False)
val_dataloader = DataLoader(val_data, batch_size=1, shuffle=False)


In [4]:
train_data[1].adj_mat

torch.Size([196, 196])

In [4]:
def del_nucl_conn(y):
#     main_diag = torch.zeros([1, y.size()[0] // bs])
#     diag1 = torch.ones([1, (y.size()[0] // bs) - 1])
    
    main_diag = torch.zeros([1, y.size()[0]])
    diag1 = torch.ones([1, y.size()[0] - 1])
#     batch_diags = []
#     for i in range(bs):
#         batch_diags.append((torch.diag_embed(main_diag).to(device) +
#             torch.diag_embed(diag1, offset=1).to(device) + 
#             torch.diag_embed(diag1, offset=-1).to(device) - 1)[0])

#     return y * torch.cat(batch_diags, 0).to(device) * (-1)
    return y * (torch.diag_embed(main_diag).to(device) + 
                    torch.diag_embed(diag1, offset=1).to(device) + 
                torch.diag_embed(diag1, offset=-1).to(device) - 1)[0] * (-1)


def del_main_diag(y):
#     main_diag = torch.ones([1, y.size()[0] // bs])
    main_diag = torch.zeros([1, y.size()[0]])

#     batch_diags = []
#     for i in range(bs):
#         batch_diags.append((torch.diag_embed(main_diag).to(device) - 1)[0])
    
#     return y * torch.cat(batch_diags, 0).to(device) * (-1)
    return y * (torch.diag_embed(main_diag).to(device) - 1)[0] * (-1)


def adj_mat_split(y):
    y = list(torch.split(y, 196, dim=0))
#     print(y)
    for i in range(len(y)):
        y[i] = torch.split(y[i], 196, dim=1)
    
    y1 = []
    for i in range(len(y)):
        y1.append(y[i][i])
    
    return torch.cat(y1, 0).to(device)
    

def argm(y):
    max_ind = torch.argmax(y, dim=1)
    y1 = torch.zeros_like(y)
    k = 0
    for i in max_ind:
        y1[k][i] = 1.
        y1[i][k] = 1.
        k += 1
    
    return y1 * y


In [5]:
a = torch.rand([4, 4])
print(a)
adj_mat_split(a)

tensor([[0.4004, 0.8519, 0.5511, 0.0188],
        [0.8301, 0.4703, 0.4650, 0.6933],
        [0.4348, 0.7514, 0.2455, 0.0012],
        [0.9674, 0.0052, 0.9604, 0.7480]])


NameError: name 'device' is not defined

In [6]:
def precision(y_pred, y_true):
    y_pred = del_nucl_conn(y_pred)
#     y_pred[(y_pred > 0.5)] = 1
#     y_pred[(y_pred <= 0.5)] = 0 
#     y_pred = argm(y_pred)
    y_pred = del_main_diag(y_pred)
    y_pred = adj_mat_split(y_pred)

    tp = torch.sum(y_pred * y_true)
    fp = torch.sum((1 - y_true) * y_pred)
    
    return tp / (tp + fp + epsilon)

def recall(y_pred, y_true):
    y_pred = del_nucl_conn(y_pred)
#     y_pred[(y_pred > 0.5)] = 1
#     y_pred[(y_pred <= 0.5)] = 0 
#     y_pred = argm(y_pred)
    y_pred = del_main_diag(y_pred)
    y_pred = adj_mat_split(y_pred)

    tp = torch.sum(y_pred * y_true)
    fn = torch.sum(y_true * (1 - y_pred))
    
    return tp / (tp + fn + epsilon)

def f_loss(y_pred, y_true):
    y_pred = del_nucl_conn(y_pred)
#     y_pred = argm(y_pred)
    y_pred = del_main_diag(y_pred)
    y_pred = adj_mat_split(y_pred)

    tp = torch.sum(y_pred * y_true)
    fn = torch.sum(y_true * (1 - y_pred))
    fp = torch.sum((1 - y_true) * y_pred)
    precision = tp / (tp + fp + epsilon)
    recall = tp / (tp + fn + epsilon)

    k = 1
    k1 = 1 - torch.abs(precision - recall)

#     k2 = 1 - torch.abs(K.mean(precision) - K.mean(recall))
    #calculate upgraded f1 score
    f1 = ((1 + k * k) * precision * recall) / (k * k * precision + recall + epsilon)
#     tw = K.sum(K.cast(y_true * y_pred, ’float32’), axis=[1, 2, 3])
#     fw = K.sum(K.cast((1 - y_true) * y_pred, ’float32’), axis=[1, 2, 3])
#     fb = K.sum(K.cast(y_true * (1 - y_pred), ’float32’), axis=[1, 2, 3])
    return 1 - f1


In [28]:
class DeepGCNModel(torch.nn.Module):
    def __init__(self, num_layers, out_channels):
        super(DeepGCNModel, self).__init__()
        self.node_encoder = Linear(dataset.num_features, out_channels)
#         self.edge_encoder = Linear(train_data.edge_attr.size(-1), hidden_channels)

        self.layers = torch.nn.ModuleList()
        for i in range(1, num_layers + 1):
#             conv = GENConv(hidden_channels, hidden_channels, aggr='softmax',
#                            t=1.0, learn_t=True, num_layers=2, norm='layer')
            conv = GatedGraphConv(out_channels, num_layers=4)
            norm = LayerNorm(out_channels, elementwise_affine=False)
            act = ReLU()

            layer = DeepGCNLayer(conv, block='res')
            self.layers.append(layer)

        self.lin = Linear(out_channels, 1)

    def forward(self, x, edge_index):
#         x = self.node_encoder(x)
#         edge_attr = self.edge_encoder(edge_attr)
#         print(x)
        x = self.layers[0].conv(x, edge_index)
#         print(x)
        for layer in self.layers[1:]:
            x = layer(x, edge_index)
#         x = self.layers[0].act(self.layers[0].norm(x))
#         x = F.dropout(x, p=0.1, training=self.training)
#         print(x)
        s = x.size()[0]
#         for i in range(s):
#             print(self.lin(x[i]))
#             x[i] = self.lin(x[i])
        
#         return self.lin(x)
        return x.sigmoid()

In [29]:
num_layers = 6
# num_features = dataset.num_features
out_channels = 200


lr = 0.00001
epochs = 1500
beta = 1


model = DeepGCNModel(num_layers, out_channels)
# model = torch.load("./models1/GatedGCN_6_4e-05_Adam_200_2_60.pt")

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)
train_data = list(map(lambda x: x.to(device), train_data))
val_data = list(map(lambda x: x.to(device), val_data))
print(device)

optimizer = torch.optim.Adam(model.parameters(), lr=lr)
# optimizer = torch.optim.Adagrad(model.parameters(), lr=lr)
# optimizer = torch.optim.RMSprop(model.parameters(), lr=lr)
# optimizer = torch.optim.Adamax(model.parameters(), lr=lr)
# optimizer = torch.optim.Adadelta(model.parameters(), lr=lr)

# optimizer = torch.optim.SGD(model.parameters(), lr=0.001)

# criterion = torch.nn.BCELoss()
# criterion = torch.nn.BCEWithLogitsLoss()
# criterion = torch.nn.CrossEntropyLoss()
# criterion = torch.nn.KLDivLoss()

# def RMSELoss(y_pred, y_true):
#     return torch.sqrt(torch.mean((y_pred - y_true) ** 2))

criterion = f_loss

cuda


In [31]:
model(train_data[0].x, train_data[0].edge_index)

tensor([[-0.0606],
        [-0.1172],
        [-0.1171],
        [-0.1886],
        [-0.1391],
        [-0.1822],
        [-0.1245],
        [-0.1004],
        [-0.0970],
        [-0.1148],
        [-0.1130],
        [-0.1252],
        [-0.1241],
        [-0.0981],
        [-0.1179],
        [-0.0809],
        [-0.1006],
        [-0.1143],
        [-0.1180],
        [-0.1343],
        [-0.1245],
        [-0.1651],
        [-0.0910],
        [-0.0969],
        [-0.0785],
        [-0.0734],
        [-0.0640],
        [-0.0574],
        [-0.0708],
        [-0.0807],
        [-0.0749],
        [-0.0888],
        [-0.0859],
        [-0.1144],
        [-0.0947],
        [-0.1097],
        [-0.0997],
        [-0.1064],
        [-0.1456],
        [-0.1173],
        [-0.1662],
        [-0.1096],
        [-0.0804]], device='cuda:0', grad_fn=<AddmmBackward0>)

In [36]:
run_name = "ResGCN_bs4_" + str(num_layers) + "*4_" + str(lr) + "_" + "Adam_" + str(out_channels) + "_" + str(beta)


wandb.init(
    # set the wandb project where this run will be logged
    project="secondary_structure_prediction1",
    
    # track hyperparameters and run metadata
    config={
    "learning_rate": lr,
    "architecture": "DeepGCN",
    "epochs": epochs,
    "optimizer": "Adam",
    "out_channels": out_channels,
    "loss": "f_loss",
    "beta": beta,
    "train:val:test": "655:218:218"
    },
    name=run_name
)

epsilon = 1e-10

VBox(children=(Label(value='0.002 MB of 0.002 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
train_f1,▁▅▆▆▆▆▆▆▇██████▆▇
train_loss,█▄▃▃▃▃▃▃▂▁▁▁▁▁▁▃▂
train_precision,▁▅▆▆▆▆▆▆▇██████▆▇
train_recall,▃▅▄▅▄▄▃▃▄██████▁▄
val_f1,▁▅▆▆▆▆▆█▇██████▇█
val_loss,█▄▃▃▃▃▃▁▂▁▁▁▁▁▁▂▁
val_precision,▁▅▆▆▆▆▆█▇██████▇█
val_recall,▁▆▇▅▅▃▂▇▅█████▇▅█

0,1
train_f1,0.00369
train_loss,0.99631
train_precision,0.00185
train_recall,0.97671
val_f1,0.00356
val_loss,0.99645
val_precision,0.00178
val_recall,0.98578


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016669799849235764, max=1.0…

In [37]:
def train():
    for epoch in range(1, epochs + 1):
        model.train()
        train_loss = []
        train_recall = []
        train_precision = []
        for g in tqdm(train_dataloader):
            g.to(device)
            optimizer.zero_grad()

            out = model(g.x, g.edge_index)
            y_true = g.adj_mat
#             print(out.size(),y_true.size())
            loss = criterion(out, y_true.to(torch.float32))            
            
            loss.backward()
            optimizer.step()
            
#             out = out.sigmoid()
            
            train_loss.append(loss.item())
            train_precision.append(precision(out, y_true).item())
            train_recall.append(recall(out, y_true).item())
            
        train_prec = np.mean(train_precision)
        train_rec = np.mean(train_recall)
        train_f1 = (2 * train_prec * train_rec) / (train_prec + train_rec)
        print(f'Epoch: {epoch:03d}, loss: {np.mean(train_loss)}, f1: {train_f1}, precision: {train_prec}, recall: {train_rec}')
        
        val_loss = []
        val_recall = []
        val_precision = []
        with torch.no_grad():
            for g in val_dataloader:
                g.to(device)
                out = model(g.x, g.edge_index)
        
                y_true = g.adj_mat
                loss = criterion(out, y_true.to(torch.float32))
                                
                val_loss.append(loss.item())
                
#                 out = out.sigmoid()
                
                val_precision.append(precision(out, y_true).item())
                val_recall.append(recall(out, y_true).item())
 
            
            prec = np.mean(val_precision)
            rec = np.mean(val_recall)
            f1 = (2 * prec * rec) / (prec + rec)
            print(f'val_loss: {np.mean(val_loss)}, val_f1: {f1}, val_precision: {prec}, val_recall: {rec}')
              
            wandb.log({"train_loss": np.mean(train_loss), "train_f1": train_f1, "train_precision": train_prec, 
                       "train_recall": train_rec,
                       "val_loss": np.mean(val_loss), "val_f1": f1, "val_precision": prec, "val_recall": rec})
#     wandb.finish()

In [38]:
train()
torch.save(model, "./models1/" + run_name + "_" + str(epochs) + ".pt")

  0%|          | 0/164 [00:00<?, ?it/s]

Epoch: 001, loss: 0.9970172232970959, f1: 0.0029833825243262557, precision: 0.0014939595653786233, recall: 0.9824640790863735
val_loss: 0.9962828766315355, val_f1: 0.003720608049378942, val_precision: 0.0018638778751456486, val_recall: 0.9702083383678296


  0%|          | 0/164 [00:00<?, ?it/s]

Epoch: 002, loss: 0.9968186758640336, f1: 0.003182029413234458, precision: 0.0015935917907016828, recall: 0.9838359526744703
val_loss: 0.9961646035176899, val_f1: 0.0038388646519574845, val_precision: 0.0019232413504291924, val_recall: 0.9691540883221758


  0%|          | 0/164 [00:00<?, ?it/s]

Epoch: 003, loss: 0.9967259393959511, f1: 0.003274774779179742, precision: 0.00164012404292873, recall: 0.9813148004979622
val_loss: 0.9961094238342495, val_f1: 0.003893997244927448, val_precision: 0.0019509260329885682, val_recall: 0.9671640575205515


  0%|          | 0/164 [00:00<?, ?it/s]

Epoch: 004, loss: 0.9966351219066759, f1: 0.0033655874889711016, precision: 0.00168568465862957, recall: 0.981232737259167
val_loss: 0.9960243406645749, val_f1: 0.0039790941895839135, val_precision: 0.0019936423139019716, val_recall: 0.9685550802344576


  0%|          | 0/164 [00:00<?, ?it/s]

Epoch: 005, loss: 0.9965501293176557, f1: 0.0034505760825444343, precision: 0.0017283305118681023, recall: 0.9800811115561462
val_loss: 0.9959405778198067, val_f1: 0.004062870935398952, val_precision: 0.0020357006779629437, val_recall: 0.9695640550840885


  0%|          | 0/164 [00:00<?, ?it/s]

Epoch: 006, loss: 0.9964173577907609, f1: 0.0035833743519485617, precision: 0.0017949578575165261, recall: 0.9832822098964598
val_loss: 0.9958375317787905, val_f1: 0.004165932087304215, val_precision: 0.0020874479478235793, val_recall: 0.9701419367429314


  0%|          | 0/164 [00:00<?, ?it/s]

Epoch: 007, loss: 0.9962841872035003, f1: 0.0037165604971285205, precision: 0.0018618098895275602, recall: 0.9802029668921377
val_loss: 0.9956986783841334, val_f1: 0.00430485430876195, val_precision: 0.00215719986036516, val_recall: 0.9728685513548895


  0%|          | 0/164 [00:00<?, ?it/s]

Epoch: 008, loss: 0.9960891880640169, f1: 0.0039115721751522975, precision: 0.001959671839580462, recall: 0.9863467652623247
val_loss: 0.9955269185774916, val_f1: 0.004476634587582354, val_precision: 0.0022434534164174425, val_recall: 0.9776948381454573


  0%|          | 0/164 [00:00<?, ?it/s]

Epoch: 009, loss: 0.9958843243558232, f1: 0.004116461310347417, precision: 0.0020625305297449457, recall: 0.9872761386923674
val_loss: 0.9953811234290447, val_f1: 0.004622482359945438, val_precision: 0.0023167392893410246, val_recall: 0.9738880931784254


  0%|          | 0/164 [00:00<?, ?it/s]

Epoch: 010, loss: 0.9956619630499584, f1: 0.004338843504896178, precision: 0.0021741930122072695, recall: 0.9885736373139591
val_loss: 0.9952078125345598, val_f1: 0.004795823085717932, val_precision: 0.0024038097736267627, val_recall: 0.9772630861592949


  0%|          | 0/164 [00:00<?, ?it/s]

Epoch: 011, loss: 0.9954580661000275, f1: 0.004542750175005795, precision: 0.0022766025812881886, recall: 0.9891964676903515
val_loss: 0.9950460589260136, val_f1: 0.004957618067808758, val_precision: 0.0024851004215525547, val_recall: 0.97913047479927


  0%|          | 0/164 [00:00<?, ?it/s]

Epoch: 012, loss: 0.9952940755501026, f1: 0.004706756674658204, precision: 0.0023589945904393794, recall: 0.9884894177681063
val_loss: 0.9948892836723853, val_f1: 0.005114459859800882, val_precision: 0.002563902869934874, val_recall: 0.9825487899670907


  0%|          | 0/164 [00:00<?, ?it/s]

Epoch: 013, loss: 0.995163792517127, f1: 0.004837039987525067, precision: 0.002424448887922619, recall: 0.988983434511394
val_loss: 0.9948016714065446, val_f1: 0.005202142754719768, val_precision: 0.002607976532740937, val_recall: 0.9823867440770525


  0%|          | 0/164 [00:00<?, ?it/s]

Epoch: 014, loss: 0.9950815820839347, f1: 0.004919259718614511, precision: 0.0024657604879438424, recall: 0.9892718155936497
val_loss: 0.9947826206137281, val_f1: 0.005221111229550582, val_precision: 0.0026175318293299083, val_recall: 0.9795014703765922


  0%|          | 0/164 [00:00<?, ?it/s]

Epoch: 015, loss: 0.9950211258923135, f1: 0.004979716378445942, precision: 0.0024961324070762025, recall: 0.9905642201987709
val_loss: 0.9947482396703248, val_f1: 0.005255510237604673, val_precision: 0.002634827583604236, val_recall: 0.9789630438483089


  0%|          | 0/164 [00:00<?, ?it/s]

Epoch: 016, loss: 0.9949890618644109, f1: 0.005011781884033899, precision: 0.0025122597807927466, recall: 0.9884767459660042
val_loss: 0.9947330683743189, val_f1: 0.005270663645891786, val_precision: 0.0026424549610264293, val_recall: 0.9776232849567308


  0%|          | 0/164 [00:00<?, ?it/s]

Epoch: 017, loss: 0.9949494836533942, f1: 0.005051363349964549, precision: 0.002532141103917483, recall: 0.9900847967804932
val_loss: 0.9946591433035125, val_f1: 0.005344655703219712, val_precision: 0.002679607422687028, val_recall: 0.9836828943786271


  0%|          | 0/164 [00:00<?, ?it/s]

Epoch: 018, loss: 0.9949167476194661, f1: 0.005084099733404975, precision: 0.0025485849125799157, recall: 0.9913671699965872
val_loss: 0.9946370444713383, val_f1: 0.005366793228357363, val_precision: 0.002690733408815988, val_recall: 0.9841225324967585


  0%|          | 0/164 [00:00<?, ?it/s]

Epoch: 019, loss: 0.9948991809676333, f1: 0.005101665127204741, precision: 0.002557410735412637, recall: 0.9916929460880233
val_loss: 0.9946618987879622, val_f1: 0.0053418390331499265, val_precision: 0.002678207685309035, val_recall: 0.9814917076071468


  0%|          | 0/164 [00:00<?, ?it/s]

Epoch: 020, loss: 0.9948814863838801, f1: 0.005119361106530381, precision: 0.0025663124422241773, recall: 0.9905081482195273
val_loss: 0.9946156398418846, val_f1: 0.005388160462556036, val_precision: 0.002701471423128953, val_recall: 0.9846829717312384


  0%|          | 0/164 [00:00<?, ?it/s]

Epoch: 021, loss: 0.9948647978102289, f1: 0.005136051873786606, precision: 0.0025746919198398957, recall: 0.9918830700036956
val_loss: 0.9945983326216357, val_f1: 0.005405502430346675, val_precision: 0.002710189238573006, val_recall: 0.984800243869834


  0%|          | 0/164 [00:00<?, ?it/s]

Epoch: 022, loss: 0.9948548540109541, f1: 0.005145995149769182, precision: 0.002579693346732955, recall: 0.9913038967586145
val_loss: 0.9946307693052729, val_f1: 0.005372960296376774, val_precision: 0.0026938508844393983, val_recall: 0.9818526503689792


  0%|          | 0/164 [00:00<?, ?it/s]

Epoch: 023, loss: 0.9948391645419888, f1: 0.005161685189639637, precision: 0.002587574446806684, recall: 0.9920185949744248
val_loss: 0.9945798683057138, val_f1: 0.005423957227225156, val_precision: 0.002719466092007303, val_recall: 0.9849956812114891


  0%|          | 0/164 [00:00<?, ?it/s]

Epoch: 024, loss: 0.9948302714562998, f1: 0.005170577922899805, precision: 0.002592050212026551, recall: 0.9911182442816292
val_loss: 0.9945830046583753, val_f1: 0.005420764387772206, val_precision: 0.0027178603969674575, val_recall: 0.9850548781933041


  0%|          | 0/164 [00:00<?, ?it/s]

Epoch: 025, loss: 0.9948242859142583, f1: 0.005176563502226152, precision: 0.0025950600914855875, recall: 0.9909131599635612
val_loss: 0.9945743111295438, val_f1: 0.005429464462003949, val_precision: 0.0027222341473092573, val_recall: 0.9850998701305564


  0%|          | 0/164 [00:00<?, ?it/s]

Epoch: 026, loss: 0.9948128741688844, f1: 0.005187976750817617, precision: 0.0026007917938220154, recall: 0.9916226329599939
val_loss: 0.9945664124204479, val_f1: 0.005437367310510458, val_precision: 0.002726207207995168, val_recall: 0.985132321578647


  0%|          | 0/164 [00:00<?, ?it/s]

Epoch: 027, loss: 0.9948084783263322, f1: 0.005192367172881898, precision: 0.0026029990535235124, recall: 0.9915480973517022
val_loss: 0.9945652257958684, val_f1: 0.005438542871868034, val_precision: 0.002726798097190763, val_recall: 0.9851518388188213


  0%|          | 0/164 [00:00<?, ?it/s]

Epoch: 028, loss: 0.9948020439322401, f1: 0.005198805994416198, precision: 0.002606235834608059, recall: 0.9914852122708064
val_loss: 0.9945927193952263, val_f1: 0.005411010180913948, val_precision: 0.002712978793061424, val_recall: 0.9821073065657134


  0%|          | 0/164 [00:00<?, ?it/s]

Epoch: 029, loss: 0.9947973682386119, f1: 0.0052034804261243805, precision: 0.0026085853655453454, recall: 0.991484333102296
val_loss: 0.9945601632288836, val_f1: 0.0054435901335992895, val_precision: 0.0027293355517631927, val_recall: 0.9851743059420804


  0%|          | 0/164 [00:00<?, ?it/s]

Epoch: 030, loss: 0.9947895055137029, f1: 0.005211342521127428, precision: 0.0026125354041407886, recall: 0.9917345613968082
val_loss: 0.9945521083993649, val_f1: 0.005451661288021997, val_precision: 0.0027333934971082254, val_recall: 0.985179489905681


  0%|          | 0/164 [00:00<?, ?it/s]

Epoch: 031, loss: 0.9947820691073813, f1: 0.00521878022014237, precision: 0.0026162736768771825, recall: 0.9917647129878765
val_loss: 0.994554357244334, val_f1: 0.0054493825402224215, val_precision: 0.002732247767282154, val_recall: 0.9851828791679592


  0%|          | 0/164 [00:00<?, ?it/s]

Epoch: 032, loss: 0.9947805186597313, f1: 0.005220330372410802, precision: 0.002617054779675971, recall: 0.9914874230943075
val_loss: 0.994574895419112, val_f1: 0.005428843134122581, val_precision: 0.0027219445322888025, val_recall: 0.9821266906523923


  0%|          | 0/164 [00:00<?, ?it/s]

Epoch: 033, loss: 0.9947853233756089, f1: 0.00521552155409893, precision: 0.0026146484774582815, recall: 0.989935296337779
val_loss: 0.9945424989275976, val_f1: 0.0054612712249793366, val_precision: 0.002738225141619204, val_recall: 0.9851856365663196


  0%|          | 0/164 [00:00<?, ?it/s]

Epoch: 034, loss: 0.9947704730964289, f1: 0.005230372685549888, precision: 0.0026220991677269566, recall: 0.9919646425945002
val_loss: 0.9945397778935389, val_f1: 0.005463983579831026, val_precision: 0.0027395888658105867, val_recall: 0.9851861943345551


  0%|          | 0/164 [00:00<?, ?it/s]

Epoch: 035, loss: 0.9947716401117604, f1: 0.005229203716129787, precision: 0.002621518799015775, recall: 0.9909334157298251
val_loss: 0.9945343713694756, val_f1: 0.005469420060310897, val_precision: 0.002742322249906223, val_recall: 0.9851865038412426


  0%|          | 0/164 [00:00<?, ?it/s]

Epoch: 036, loss: 0.9947719824750249, f1: 0.005228861320885695, precision: 0.0026213530253721175, recall: 0.9900295334618266
val_loss: 0.9945271586606262, val_f1: 0.0054766511617369215, val_precision: 0.0027459579688943326, val_recall: 0.9851866708982975


  0%|          | 0/164 [00:00<?, ?it/s]

Epoch: 037, loss: 0.9947640132613298, f1: 0.005236833108569776, precision: 0.002625351716698965, recall: 0.991218948509635
val_loss: 0.9945294457838076, val_f1: 0.005474330060498955, val_precision: 0.0027447909410720384, val_recall: 0.9851866818349296


  0%|          | 0/164 [00:00<?, ?it/s]

Epoch: 038, loss: 0.9947612500045357, f1: 0.005239597741274552, precision: 0.0026267434920400107, recall: 0.9909164567546147
val_loss: 0.9945317189627831, val_f1: 0.005472028386787529, val_precision: 0.0027436336831920226, val_recall: 0.9851867903810029


  0%|          | 0/164 [00:00<?, ?it/s]

Epoch: 039, loss: 0.9947655687971813, f1: 0.00523527501415216, precision: 0.0026245852027275804, recall: 0.9888476714855288
val_loss: 0.9945205704334679, val_f1: 0.0054832235505000385, val_precision: 0.0027492625176652394, val_recall: 0.9851868070593668


  0%|          | 0/164 [00:00<?, ?it/s]

Epoch: 040, loss: 0.9947498895046187, f1: 0.005250959161787035, precision: 0.0026324463659675975, recall: 0.9920550308576445
val_loss: 0.9945557847482349, val_f1: 0.005447938465042873, val_precision: 0.002731545181189869, val_recall: 0.982140370328492


  0%|          | 0/164 [00:00<?, ?it/s]

Epoch: 041, loss: 0.9947445134564143, f1: 0.005256334328059715, precision: 0.0026351482385241357, recall: 0.9920550308576445
val_loss: 0.9945245677724891, val_f1: 0.005479173387986338, val_precision: 0.002747226122381174, val_recall: 0.9851868160820882


  0%|          | 0/164 [00:00<?, ?it/s]

Epoch: 042, loss: 0.9947407739918407, f1: 0.005260073813619256, precision: 0.0026370279310645945, recall: 0.9920550308576445
val_loss: 0.9945350220990837, val_f1: 0.005468690715211961, val_precision: 0.0027419614684105844, val_recall: 0.9844223389384943


  0%|          | 0/164 [00:00<?, ?it/s]

Epoch: 043, loss: 0.9947379126054484, f1: 0.00526293337492771, precision: 0.0026384673316298597, recall: 0.9917714319578032
val_loss: 0.9945114006143098, val_f1: 0.005492379819929859, val_precision: 0.0027538662608937196, val_recall: 0.9851868171757514


  0%|          | 0/164 [00:00<?, ?it/s]

Epoch: 044, loss: 0.9947327113733059, f1: 0.005268134581965831, precision: 0.002641079786926417, recall: 0.9920550308576445
val_loss: 0.9945128059715306, val_f1: 0.005490950637834229, val_precision: 0.002753147669830744, val_recall: 0.9851868171757514


  0%|          | 0/164 [00:00<?, ?it/s]

Epoch: 045, loss: 0.9947294291199708, f1: 0.005271417465046268, precision: 0.002642729982631508, recall: 0.9920550308576445
val_loss: 0.9945021318186313, val_f1: 0.005501672988918722, val_precision: 0.0027585388807400447, val_recall: 0.9851868171757514


  0%|          | 0/164 [00:00<?, ?it/s]

Epoch: 046, loss: 0.9947260843544472, f1: 0.005274764437320921, precision: 0.002644412399471238, recall: 0.9920550308576445


KeyboardInterrupt: 