In [1]:
import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
import obonet
import random
import torch
import math
from Bio import SeqIO
import Bio.PDB
import urllib.request
import py3Dmol
import pylab
import pickle as pickle
import torch.nn as nn
from torch.nn import Dropout
import torch.nn.functional as F
from torch import Tensor
from torch.nn import Sequential, Linear, ReLU
from torch_geometric.utils import erdos_renyi_graph
from torch_geometric.utils import softmax
from torch_geometric.nn import MessagePassing
from torch_geometric.nn import LayerNorm
from torch_geometric.nn.models import MLP
from torch_geometric.data import Data
import matplotlib.pyplot as plt
import os
from Bio import PDB
from rdkit import Chem
import blosum as bl

In [2]:
def createMask(graph,indicies,num_masked):
    """Build a per-node boolean mask that selects `num_masked` of the given nodes at random.

    Parameters
    ----------
    graph : object
        Only ``graph.x.size()[0]`` is read; it is used as the number of nodes.
    indicies : iterable of int-like
        Node positions that are eligible to be masked. Duplicates and positions
        outside ``range(number_of_nodes)`` are ignored.
    num_masked : int
        How many eligible positions to mark ``True``. Values below zero are
        treated as zero and values above the number of eligible positions
        mask all of them.

    Returns
    -------
    list of bool
        One plain Python ``bool`` per node; exactly the chosen positions are
        ``True``. Selection uses the ``random`` module, so it follows
        ``random.seed``.
    """
    size = graph.x.size()[0]

    # Normalise the eligible positions once so membership tests are O(1) and
    # every entry is a plain int, whatever container type was passed in.
    eligible = sorted({int(idx) for idx in indicies if 0 <= int(idx) < size})

    # Clamp so an out-of-range request cannot raise inside random.sample.
    num_selected = max(0, min(int(num_masked), len(eligible)))
    chosen = set(random.sample(eligible, num_selected))

    # Build the result from Python bools only, so it can be used directly as a
    # boolean index without any NumPy scalar types leaking into the list.
    return [node in chosen for node in range(size)]

In [3]:
# Load the pre-built graphs from disk.
graph_list = torch.load('binding_pocket_graphs.pt')

# Keep only the graphs that have more than `smallest` designable positions,
# and report how many survive the filter.
smallest = 5
graph_list_clean = [
    entry for entry in graph_list
    if len(entry.designable_indicies) > smallest
]
count = len(graph_list_clean)
print(count)

# Attach to each kept graph a mask over its designable positions (all of them
# are requested) together with the element-wise negation of that mask.
for graph in graph_list_clean:
    designable = graph.designable_indicies
    graph.mask = createMask(graph, designable, int(len(designable)))
    graph.inv_mask = [not flag for flag in graph.mask]

5316


In [4]:
# Lookups loaded from disk: `group_assignment` is indexed by a graph's label and
# `group_size` by the resulting group (see the cell that sets `.group`/`.weight`).
group_assignment = torch.load('group_assignment_30p.pt')
group_size = torch.load('group_size_30p.pt')

In [23]:
# Tag every kept graph with its group and weight it by the inverse of that
# group's size.
for graph in graph_list_clean:
    graph.group = group_assignment[graph.label]
    graph.weight = 1.0 / group_size[graph.group]

In [25]:
# Spot-check: print one randomly chosen graph from the unfiltered list.
print(random.choice(graph_list))

Data(x=[89, 133], edge_index=[2, 2289], edge_attr=[2289, 114], pos=[88, 3], y=[89, 20], label='1ndw', designable_indicies=[23], mask=[89], inv_mask=[89], Kd=[11], group='1ndw', weight=0.1111111111111111)


In [26]:
class Graph_Attn(MessagePassing):
    """Multi-head attention message-passing layer using node and edge features.

    For every edge, a query is projected from ``x_i`` while a key and a value
    are projected from ``x_j`` concatenated with the edge attributes. The
    per-head query/key products are summed over the channel axis, scaled,
    normalised with a softmax over edges sharing the same ``index``, passed
    through dropout and used to weight the values. Messages are summed per
    node, the heads are flattened and a final linear layer maps them to
    ``out_channels``.

    Parameters
    ----------
    in_channels : width of the node features.
    edge_attr_dim : width of the edge attributes.
    hidden_channels : width of each attention head.
    out_channels : width of the returned node features.
    nheads : number of attention heads.
    attn_dropout : dropout probability applied to the attention weights.
    mlp_dropout : accepted but not used by this implementation.
    neg_slope : slope of the stored LeakyReLU (which ``forward``/``message``
        never call).
    """
    def __init__(self, in_channels, edge_attr_dim, hidden_channels, out_channels, nheads = 8, attn_dropout = 0.5, mlp_dropout = 0.0, neg_slope = 0.2):
        # 'add' aggregation: messages arriving at a node are summed.
        super().__init__(node_dim=0, aggr='add')

        self.in_channels = float(in_channels)
        self.neg_slope = neg_slope
        self.attn_dropout = attn_dropout
        self.relu = nn.LeakyReLU(negative_slope = neg_slope)
        self.nheads = nheads
        self.c = hidden_channels

        all_heads_width = nheads * hidden_channels
        edge_input_width = in_channels + edge_attr_dim
        self.Wq = Linear(in_channels, all_heads_width)
        self.Wz = Linear(edge_input_width, all_heads_width)
        self.Wv = Linear(edge_input_width, all_heads_width)
        self.W0 = Linear(all_heads_width, out_channels)

        self.dropout = nn.Dropout(p = self.attn_dropout)

    def forward(self, x, edge_index, edge_attr)-> Tensor:
        """Run message passing, then merge the heads and project to ``out_channels``."""
        per_head = self.propagate(edge_index, x=x, edge_attr=edge_attr)
        merged = per_head.view(-1, self.nheads * self.c)
        return self.W0(merged)

    def message(self, x_j: Tensor, x_i: Tensor, edge_attr: Tensor, index)-> Tensor:
        """Return the attention-weighted value of every edge, one slice per head."""
        head_shape = (-1, self.nheads, self.c)
        # NOTE(review): the scale is 1/sqrt(in_channels), not 1/sqrt(per-head width);
        # confirm this is intended before changing it, since it affects trained weights.
        scale = (1.0 / self.in_channels) ** 0.5

        edge_input = torch.cat([x_j , edge_attr], dim=-1)
        query = self.Wq(x_i).view(*head_shape)
        key = self.Wz(edge_input).view(*head_shape)
        value = self.Wv(edge_input).view(*head_shape)

        scores = (query * key * scale).sum(dim = -1)
        weights = self.dropout(softmax(scores, index))
        return value * weights.unsqueeze(-1)
    
# Smoke test: push a small random graph (14 nodes, 7 node features, 5 edge
# features) through the attention layer and check the output shape.
x = torch.rand(14,7)
edge_index  = erdos_renyi_graph(14, 0.5)
edge_attr = torch.rand(edge_index.size()[1], 5)

graph = Data(x = x, edge_index = edge_index, edge_attr = edge_attr)
model = Graph_Attn(7,5,10,7)
out = model(graph.x,graph.edge_index,graph.edge_attr)
# One row per node, `out_channels` (7) columns.
assert out.size() == (14, 7), out.size()

In [27]:
class MLP_Conv_nodes(MessagePassing):
    """Message-passing layer whose per-edge message is an MLP of both endpoints and the edge.

    The MLP input is ``[x_j, x_i, edge_attr]`` concatenated on the last axis,
    and the resulting messages are summed per node.

    Parameters
    ----------
    in_channels : width of the node features.
    edge_attr_dim : width of the edge attributes.
    hidden_channels : hidden width of the two-layer MLP.
    out_channels : width of each message (and of the returned node features).
    """
    def __init__(self, in_channels, edge_attr_dim, hidden_channels, out_channels):
        # 'add' aggregation: messages arriving at a node are summed.
        super().__init__(aggr='add')
        self.mlp = MLP(
            in_channels=2 * in_channels + edge_attr_dim,
            hidden_channels=hidden_channels,
            out_channels=out_channels,
            num_layers=2,
            norm='layer',
            dropout=0.1,
            act='gelu',
        )

    def forward(self, x, edge_index, edge_attr)-> Tensor:
        """Return the aggregated messages for every node."""
        return self.propagate(edge_index, x=x, edge_attr=edge_attr)

    def message(self, x_j: Tensor, x_i: Tensor, edge_attr: Tensor)-> Tensor:
        """Apply the MLP to the concatenated endpoint and edge features of each edge."""
        edge_input = torch.cat([x_j, x_i, edge_attr], dim=-1)
        return self.mlp(edge_input)

In [28]:
class MLP_Conv_edges(MessagePassing):
    """Edge-update layer: an MLP of both endpoints and the current edge attributes.

    ``forward`` goes through ``edge_updater`` rather than ``propagate``, so the
    value returned is whatever ``edge_update`` produces from
    ``[x_j, x_i, edge_attr]`` concatenated on the last axis.

    Parameters
    ----------
    in_channels : width of the node features.
    edge_attr_dim : width of the incoming edge attributes.
    hidden_channels : hidden width of the two-layer MLP.
    out_channels : width of the MLP output.
    """
    def __init__(self, in_channels, edge_attr_dim, hidden_channels, out_channels):
        super().__init__(aggr='add')
        self.mlp = MLP(
            in_channels=2 * in_channels + edge_attr_dim,
            hidden_channels=hidden_channels,
            out_channels=out_channels,
            num_layers=2,
            norm='layer',
            dropout=0.1,
            act='gelu',
        )

    def forward(self, x, edge_index, edge_attr)-> Tensor:
        """Return the result of the edge update for the given graph."""
        return self.edge_updater(edge_index, x=x, edge_attr=edge_attr)

    def edge_update(self, x_j: Tensor, x_i: Tensor, edge_attr: Tensor)-> Tensor:
        """Apply the MLP to the concatenated endpoint and edge features of each edge."""
        edge_input = torch.cat([x_j, x_i, edge_attr], dim=-1)
        return self.mlp(edge_input)

In [29]:
class Encoder_Layer(torch.nn.Module):
    """One encoder block that updates node features and then edge features.

    Node path: 3-head ``Graph_Attn`` -> dropout -> residual add -> LayerNorm,
    followed by a two-layer GELU MLP -> dropout -> residual add -> LayerNorm.
    Edge path: ``MLP_Conv_edges`` on the updated nodes -> dropout -> residual
    add -> LayerNorm. Node and edge widths are unchanged by the block.

    Parameters
    ----------
    node_size : width of the node features.
    edge_size : width of the edge attributes.
    hidden_size : per-head width of the attention layer; the edge MLP uses
        four times this value as its hidden width.
    """
    def __init__(self, node_size, edge_size, hidden_size):
        super().__init__()
        self.dropout = nn.Dropout(p = 0.1)
        self.msg1 = Graph_Attn(node_size, edge_size, hidden_size, node_size, nheads = 3)
        self.norm_node1 = LayerNorm(node_size, mode = 'node')
        self.norm_node2 = LayerNorm(node_size, mode = 'node')
        self.norm_edge1 = LayerNorm(edge_size, mode = 'node')
        self.feed_forward = MLP(
            in_channels=node_size,
            hidden_channels=4 * node_size,
            out_channels=node_size,
            num_layers=2,
            norm='layer',
            dropout=0.1,
            act='gelu',
        )
        self.edge_message = MLP_Conv_edges(node_size, edge_size, int(4 * hidden_size), edge_size)

    def forward(self, x, edge_index, edge_attr):
        """Return ``(updated_node_features, updated_edge_attributes)``."""
        # Attention sub-layer with residual connection.
        attended = self.norm_node1(x + self.dropout(self.msg1(x, edge_index, edge_attr)))
        # Feed-forward sub-layer with residual connection.
        nodes_out = self.norm_node2(attended + self.dropout(self.feed_forward(attended)))
        # Edge update computed from the freshly updated nodes.
        edge_delta = self.edge_message(nodes_out, edge_index, edge_attr)
        edges_out = self.norm_edge1(edge_attr + self.dropout(edge_delta))
        return nodes_out, edges_out

In [30]:
class AA_Classifier(torch.nn.Module):
    """Stack of four ``Encoder_Layer`` blocks followed by a linear head with 20 outputs per node.

    Sizes are fixed in the constructor: 133 node features, 114 edge features
    and a hidden size of 128 for every encoder block.
    """
    def __init__(self):
        super().__init__()
        self.node_feature_size = 133
        self.node_feature_hidden_size = 128
        self.node_feature_size_out = 133
        self.edge_dim = 114
        self.dropout = 0.1
        # `Droput`, `softmax` and `tanh` are created here but are not used in `forward`.
        self.Droput = nn.Dropout(p = self.dropout)
        self.ff_out = Linear(self.node_feature_size, 20)

        # The layers stay as individually named attributes (conv1..conv4) so
        # the parameter names in the state dict do not change.
        layer_args = (self.node_feature_size, self.edge_dim, self.node_feature_hidden_size)
        self.conv1 = Encoder_Layer(*layer_args)
        self.conv2 = Encoder_Layer(*layer_args)
        self.conv3 = Encoder_Layer(*layer_args)
        self.conv4 = Encoder_Layer(*layer_args)

        self.softmax = torch.nn.Softmax(dim=1)
        self.tanh = nn.Tanh()

    def forward(self,graph):
        """Return the per-node outputs of the linear head (no softmax applied).

        ``graph`` must expose ``x``, ``edge_index`` and ``edge_attr``.
        """
        nodes, edges = graph.x, graph.edge_attr
        for layer in (self.conv1, self.conv2, self.conv3, self.conv4):
            nodes, edges = layer(nodes, graph.edge_index, edges)
        return self.ff_out(nodes)

In [31]:
#graph_list_clean = torch.load('full_graphs_mn_rm_6_11302023.pt')

In [32]:
# Label collections that decide which split a graph belongs to.
train_name = torch.load('train_name_no_sampling.pt')
test_name = torch.load('test_name_no_sampling.pt')

# Route each kept graph by its label; the training collection is checked
# first, and a graph found in neither collection is left out of both splits.
train_data, val_data = [], []
for graph in graph_list_clean:
    label = graph.label
    if label in train_name:
        train_data.append(graph)
        continue
    if label in test_name:
        val_data.append(graph)

In [33]:
# Report how many graphs ended up in each split.
for split_name, split in (("Train:", train_data), ("Val:", val_data)):
    print(split_name, len(split))

Train: 3986
Val: 1330


In [34]:
#idxs = torch.load('idxs_11302023.pt')
#train_idxs = idxs[0].detach().numpy()
#val_idxs = idxs[1].detach().numpy()
#train_data = [graph_list_clean[int(i)] for i in train_idxs]
#val_data = [graph_list_clean[int(i)] for i in val_idxs]
#torch.save(train_data, 'train_data_12152023')
#torch.save(val_data, 'val_data_12152023')

In [35]:
#torch.save(train_data, 'train_data_12202023')
#torch.save(val_data, 'val_data_12202023')

In [36]:
from torch_geometric.loader import DataLoader
train_batch_size = 1
full_dl = DataLoader(graph_list,batch_size = 1, shuffle = True)
train_dl = DataLoader(train_data,batch_size = train_batch_size, shuffle = True)
# Validation only accumulates a total loss, so its iteration order does not
# matter; leaving it unshuffled keeps evaluation order reproducible and avoids
# drawing from the global RNG on every epoch.
val_dl = DataLoader(val_data,batch_size = 1, shuffle = False)

In [37]:
# Load the embedding lookup that the training loop indexes by three-letter code.
# SECURITY: pickle.load can execute arbitrary code; only open files from a trusted source.
with open('AA_embeddings_11172023.pkl', 'rb') as f:
    AA_embeddings = pickle.load(f)

# Three-letter codes of the 20 amino acids, used to draw random replacement embeddings.
AA_3_letters = ['ALA','ARG','ASN','ASP','CYS','GLN','GLU','GLY','HIS','ILE','LEU','LYS','MET','PHE','PRO','SER','THR','TRP','TYR','VAL']

In [38]:
# Mean-reduced, label-smoothed cross-entropy intended as the `lossf` argument below.
lossf = torch.nn.CrossEntropyLoss(reduction = 'mean', label_smoothing = 0.1)
def custom_loss(predict, truth, lossf):
    """Return the sum of `lossf` applied to the first two prediction/target pairs.

    `predict` and `truth` are indexable; only positions 0 and 1 are read.
    `lossf` is any callable taking ``(prediction, target)``.
    """
    head_losses = [lossf(predict[k], truth[k]) for k in (0, 1)]
    return head_losses[0] + head_losses[1]
    

In [45]:
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = AA_Classifier()
#model.load_state_dict(torch.load('9_encder_layers_12092023.pt'))
model.to(DEVICE) # put on GPU when one is available

# Label-smoothed cross-entropy summed over the masked nodes, optimised with Adam.
criterion = torch.nn.CrossEntropyLoss(reduction = 'sum', label_smoothing = 0.1)
optimizer = torch.optim.Adam(model.parameters(), betas = (0.9,0.98),eps =1e-9, lr=1e-5, weight_decay=1e-4)

num_epochs = 1000  # Adjust the number of epochs as needed
losses = []  # one [train, val] pair of average losses per epoch
# Checkpoint threshold: weights are saved only when the epoch's average
# validation loss falls below this value, which is then lowered to it.
# NOTE(review): 39.0 is a hard-coded starting point - confirm it is intended
# rather than float('inf').
smallest = 39.0

for epoch in range(num_epochs):
    total_loss = 0.0
    val_loss = 0.0

    # ---- training pass ----
    model.train()  # set once per pass instead of once per batch
    for batch in train_dl:
        inputs = batch.to(DEVICE)

        # `batch.mask` holds one list of flags per graph; flatten it so that
        # it lines up with the concatenated node rows of the batch.
        mask = [flag for graph_mask in batch.mask for flag in graph_mask]
        mask_t = torch.tensor(mask, dtype=torch.bool, device=DEVICE)

        optimizer.zero_grad()

        # Corrupt the masked rows: 80% are zeroed, 12% are replaced by the
        # embedding of a randomly drawn entry of AA_3_letters, and the
        # remaining 8% are left unchanged.
        with torch.no_grad():
            for i, masked in enumerate(mask):
                if not masked:
                    continue
                p = random.random()
                if p < 0.8:
                    # Scalar fill: no hard-coded feature width, no host->device copy.
                    inputs.x[i] = 0.0
                elif p < 0.92:
                    inputs.x[i] = AA_embeddings[random.choice(AA_3_letters)]

        outputs = model(inputs)

        # Loss over the masked nodes only, scaled by the graph's weight.
        # NOTE(review): multiplying by `inputs.weight` yields a single-element
        # loss only while train_batch_size == 1 - confirm before raising it.
        loss = criterion(outputs[mask_t], inputs.y[mask_t])*inputs.weight

        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    # ---- validation pass ----
    model.eval()
    with torch.no_grad():
        for batch in val_dl:
            # The validation loader uses batch size 1; take its only graph.
            inputs = batch[0].to(DEVICE)
            mask_t = torch.tensor(inputs.mask, dtype=torch.bool, device=DEVICE)

            # Every masked row is zeroed for evaluation.
            inputs.x[mask_t] = 0.0

            outputs = model(inputs)
            loss = criterion(outputs[mask_t], inputs.y[mask_t])*inputs.weight
            val_loss += loss.item()

    # Average per graph. Using the dataset length stays correct even when the
    # last training batch is smaller than train_batch_size.
    avg_loss = total_loss / len(train_data)
    avg_val_loss = val_loss / len(val_dl)
    print(f'Epoch [{epoch+1}/{num_epochs}] Loss: {avg_loss:.4f} Val Loss: {avg_val_loss:.4f}')
    losses.append([avg_loss,avg_val_loss])

    # Keep the weights of the best epoch seen so far.
    if avg_val_loss < smallest:
        torch.save(model.state_dict(), '02192024_grouping_split_reweight.pt')
        smallest = avg_val_loss

print('Training complete')


Epoch [1/1000] Loss: 18.6789 Val Loss: 21.8291
Epoch [2/1000] Loss: 18.3039 Val Loss: 21.7545
Epoch [3/1000] Loss: 18.2072 Val Loss: 21.6093
Epoch [4/1000] Loss: 18.1600 Val Loss: 21.6299
Epoch [5/1000] Loss: 18.0932 Val Loss: 21.5989
Epoch [6/1000] Loss: 18.0233 Val Loss: 21.4920
Epoch [7/1000] Loss: 18.0710 Val Loss: 21.4093
Epoch [8/1000] Loss: 17.9798 Val Loss: 21.4368
Epoch [9/1000] Loss: 17.9094 Val Loss: 21.3629
Epoch [10/1000] Loss: 17.9161 Val Loss: 21.4641
Epoch [11/1000] Loss: 17.9201 Val Loss: 21.4089
Epoch [12/1000] Loss: 17.8812 Val Loss: 21.3147
Epoch [13/1000] Loss: 17.8091 Val Loss: 21.3528
Epoch [14/1000] Loss: 17.8378 Val Loss: 21.2082
Epoch [15/1000] Loss: 17.8236 Val Loss: 21.3116
Epoch [16/1000] Loss: 17.8069 Val Loss: 21.2927
Epoch [17/1000] Loss: 17.7780 Val Loss: 21.2753
Epoch [18/1000] Loss: 17.7323 Val Loss: 21.1745
Epoch [19/1000] Loss: 17.7066 Val Loss: 21.2153
Epoch [20/1000] Loss: 17.6974 Val Loss: 21.1496
Epoch [21/1000] Loss: 17.6515 Val Loss: 21.0359
E

Epoch [171/1000] Loss: 13.7094 Val Loss: 18.0419
Epoch [172/1000] Loss: 13.6997 Val Loss: 18.1499
Epoch [173/1000] Loss: 13.6723 Val Loss: 18.0475
Epoch [174/1000] Loss: 13.6594 Val Loss: 18.0794
Epoch [175/1000] Loss: 13.6255 Val Loss: 18.0625
Epoch [176/1000] Loss: 13.5981 Val Loss: 18.1868
Epoch [177/1000] Loss: 13.6038 Val Loss: 18.0495
Epoch [178/1000] Loss: 13.5533 Val Loss: 17.9498
Epoch [179/1000] Loss: 13.5840 Val Loss: 17.9889
Epoch [180/1000] Loss: 13.4811 Val Loss: 18.0531
Epoch [181/1000] Loss: 13.5269 Val Loss: 18.0000
Epoch [182/1000] Loss: 13.4402 Val Loss: 18.0650
Epoch [183/1000] Loss: 13.4542 Val Loss: 18.0282
Epoch [184/1000] Loss: 13.4055 Val Loss: 18.0115
Epoch [185/1000] Loss: 13.3664 Val Loss: 17.9533
Epoch [186/1000] Loss: 13.3639 Val Loss: 17.9433
Epoch [187/1000] Loss: 13.3131 Val Loss: 17.8809
Epoch [188/1000] Loss: 13.3303 Val Loss: 17.8851
Epoch [189/1000] Loss: 13.3037 Val Loss: 17.8243
Epoch [190/1000] Loss: 13.2501 Val Loss: 17.8767
Epoch [191/1000] Los

Epoch [339/1000] Loss: 10.4231 Val Loss: 16.9439
Epoch [340/1000] Loss: 10.4103 Val Loss: 17.0300
Epoch [341/1000] Loss: 10.3830 Val Loss: 17.0154
Epoch [342/1000] Loss: 10.4048 Val Loss: 16.9744
Epoch [343/1000] Loss: 10.3937 Val Loss: 16.9793
Epoch [344/1000] Loss: 10.3180 Val Loss: 17.0293
Epoch [345/1000] Loss: 10.3183 Val Loss: 17.0257
Epoch [346/1000] Loss: 10.2919 Val Loss: 17.0245
Epoch [347/1000] Loss: 10.2997 Val Loss: 16.9978
Epoch [348/1000] Loss: 10.2593 Val Loss: 16.9626
Epoch [349/1000] Loss: 10.2792 Val Loss: 16.9258
Epoch [350/1000] Loss: 10.2840 Val Loss: 17.0387
Epoch [351/1000] Loss: 10.1974 Val Loss: 16.9323
Epoch [352/1000] Loss: 10.1854 Val Loss: 17.0010
Epoch [353/1000] Loss: 10.2288 Val Loss: 16.8756
Epoch [354/1000] Loss: 10.2081 Val Loss: 16.9342
Epoch [355/1000] Loss: 10.1842 Val Loss: 16.9146
Epoch [356/1000] Loss: 10.1532 Val Loss: 16.9247
Epoch [357/1000] Loss: 10.1378 Val Loss: 16.9850
Epoch [358/1000] Loss: 10.1308 Val Loss: 16.8964
Epoch [359/1000] Los

Epoch [510/1000] Loss: 8.5853 Val Loss: 16.5264
Epoch [511/1000] Loss: 8.5881 Val Loss: 16.5204
Epoch [512/1000] Loss: 8.5670 Val Loss: 16.5305
Epoch [513/1000] Loss: 8.5394 Val Loss: 16.6086
Epoch [514/1000] Loss: 8.5065 Val Loss: 16.5776
Epoch [515/1000] Loss: 8.4861 Val Loss: 16.5347
Epoch [516/1000] Loss: 8.4915 Val Loss: 16.4889
Epoch [517/1000] Loss: 8.5095 Val Loss: 16.4919
Epoch [518/1000] Loss: 8.5357 Val Loss: 16.4967
Epoch [519/1000] Loss: 8.4828 Val Loss: 16.5130
Epoch [520/1000] Loss: 8.4782 Val Loss: 16.5161
Epoch [521/1000] Loss: 8.4679 Val Loss: 16.5913
Epoch [522/1000] Loss: 8.4751 Val Loss: 16.5289
Epoch [523/1000] Loss: 8.5034 Val Loss: 16.4712
Epoch [524/1000] Loss: 8.4856 Val Loss: 16.5309
Epoch [525/1000] Loss: 8.4317 Val Loss: 16.5086
Epoch [526/1000] Loss: 8.4490 Val Loss: 16.5255
Epoch [527/1000] Loss: 8.3892 Val Loss: 16.5684
Epoch [528/1000] Loss: 8.3796 Val Loss: 16.5204
Epoch [529/1000] Loss: 8.4482 Val Loss: 16.4633
Epoch [530/1000] Loss: 8.3996 Val Loss: 

Epoch [681/1000] Loss: 7.4004 Val Loss: 16.3919
Epoch [682/1000] Loss: 7.4933 Val Loss: 16.3039
Epoch [683/1000] Loss: 7.4663 Val Loss: 16.3079
Epoch [684/1000] Loss: 7.4649 Val Loss: 16.3355
Epoch [685/1000] Loss: 7.5030 Val Loss: 16.3011
Epoch [686/1000] Loss: 7.4647 Val Loss: 16.3232
Epoch [687/1000] Loss: 7.4389 Val Loss: 16.2734
Epoch [688/1000] Loss: 7.4513 Val Loss: 16.3079
Epoch [689/1000] Loss: 7.5222 Val Loss: 16.2878
Epoch [690/1000] Loss: 7.4466 Val Loss: 16.2822
Epoch [691/1000] Loss: 7.4093 Val Loss: 16.2609
Epoch [692/1000] Loss: 7.4034 Val Loss: 16.3247
Epoch [693/1000] Loss: 7.4407 Val Loss: 16.2822
Epoch [694/1000] Loss: 7.4075 Val Loss: 16.3549
Epoch [695/1000] Loss: 7.3761 Val Loss: 16.3159
Epoch [696/1000] Loss: 7.4508 Val Loss: 16.2921
Epoch [697/1000] Loss: 7.3667 Val Loss: 16.2969
Epoch [698/1000] Loss: 7.4004 Val Loss: 16.3279
Epoch [699/1000] Loss: 7.4059 Val Loss: 16.3244
Epoch [700/1000] Loss: 7.3883 Val Loss: 16.2879
Epoch [701/1000] Loss: 7.4235 Val Loss: 

Epoch [852/1000] Loss: 6.8527 Val Loss: 16.3105
Epoch [853/1000] Loss: 6.8467 Val Loss: 16.2719
Epoch [854/1000] Loss: 6.8583 Val Loss: 16.3354
Epoch [855/1000] Loss: 6.8764 Val Loss: 16.2489
Epoch [856/1000] Loss: 6.8315 Val Loss: 16.2655
Epoch [857/1000] Loss: 6.8057 Val Loss: 16.3339
Epoch [858/1000] Loss: 6.7916 Val Loss: 16.2859
Epoch [859/1000] Loss: 6.8611 Val Loss: 16.3036
Epoch [860/1000] Loss: 6.8194 Val Loss: 16.3213
Epoch [861/1000] Loss: 6.8547 Val Loss: 16.3280
Epoch [862/1000] Loss: 6.7937 Val Loss: 16.3267
Epoch [863/1000] Loss: 6.7848 Val Loss: 16.4030
Epoch [864/1000] Loss: 6.8230 Val Loss: 16.3366
Epoch [865/1000] Loss: 6.7934 Val Loss: 16.3207
Epoch [866/1000] Loss: 6.7910 Val Loss: 16.3873
Epoch [867/1000] Loss: 6.7941 Val Loss: 16.3190
Epoch [868/1000] Loss: 6.8073 Val Loss: 16.2928
Epoch [869/1000] Loss: 6.8076 Val Loss: 16.2939
Epoch [870/1000] Loss: 6.7991 Val Loss: 16.3168
Epoch [871/1000] Loss: 6.7326 Val Loss: 16.3307
Epoch [872/1000] Loss: 6.7620 Val Loss: 

In [42]:
#torch.save(model.state_dict(), '3_encder_layers_8heads_12272023_64_hidden_.pt')
#torch.save(model.state_dict(), '02192024_down_sampled_grouping_split.pt')
# Persist the optimizer state and the per-epoch [train, val] loss history.
torch.save(optimizer.state_dict(),'02192024_grouping_split_reweight_opt.pt')
# torch.tensor is the documented way to build a tensor from existing data;
# float32 matches what the legacy torch.Tensor(...) constructor produced.
torch.save(torch.tensor(losses, dtype=torch.float32),'02192024_grouping_split_reweight_loss.pt')

In [15]:
# Number of batches the training loader yields per epoch.
# NOTE(review): the recorded output (1034) does not match the 3986 training
# graphs printed earlier at batch size 1 - presumably left over from an
# earlier run; re-run to confirm.
len(train_dl)

1034

In [31]:
# Take the validation entry (position 1) of every per-epoch loss pair and
# report the smallest one.
val_loss_hold = [epoch_pair[1] for epoch_pair in losses]
print(min(val_loss_hold))

28.06430265395232
