In [1]:
import numpy as np
import pandas as pd
import obonet
import random
import torch
import math
from Bio import SeqIO
import Bio.PDB
import urllib.request
import py3Dmol
import pylab
import pickle as pickle
import torch.nn as nn
from torch.nn import Dropout
import torch.nn.functional as F
from torch import Tensor
from torch.nn import Sequential, Linear, ReLU
from torch_geometric.utils import erdos_renyi_graph
from torch_geometric.utils import softmax
from torch_geometric.nn import MessagePassing
from torch_geometric.nn import LayerNorm
from torch_geometric.nn.models import MLP
from torch_geometric.data import Data
import matplotlib.pyplot as plt
import os
from Bio import PDB
from rdkit import Chem
import blosum as bl

In [2]:
def createMask(graph,indicies,num_masked):
    """Build a per-node boolean mask flagging a random subset of designable nodes.

    ``num_masked`` of the positions listed in ``indicies`` are chosen with
    ``random.shuffle`` and flagged ``True``; every other node is ``False``.

    Args:
        graph: Object whose ``x`` exposes ``size()``; ``graph.x.size()[0]`` is
            used as the number of nodes.
        indicies: Integer node positions eligible for masking. Anything that
            supports ``len()`` and iteration works (e.g. a list or a 1-D
            integer tensor).
        num_masked: How many of the eligible positions to flag.

    Returns:
        list: One plain Python ``bool`` per node, in node order.
    """
    size = graph.x.size()[0]

    # Built as a plain list on purpose: stacking with numpy promotes the mask to
    # float64 whenever one of the two halves is empty, which used to leak
    # 0.0 / numpy booleans into the returned mask.
    design_mask = [True] * num_masked + [False] * (len(indicies) - num_masked)
    random.shuffle(design_mask)

    # Set membership avoids rescanning `indicies` for every node; int() also
    # normalises 0-d tensor elements so they compare by value.
    designable = {int(position) for position in indicies}

    protein_mask = [False] * size
    flags = iter(design_mask)
    for node in range(size):
        if node in designable:
            # Designable nodes consume the shuffled flags in ascending node order.
            protein_mask[node] = next(flags)

    return protein_mask

In [3]:
# Load the pre-computed binding-pocket graphs from disk.
graph_list = torch.load('binding_pocket_graphs.pt')

# Keep only the graphs that have more than `smallest` designable positions.
smallest = 5
graph_list_clean = []
for entry in graph_list:
    if len(entry.designable_indicies) <= smallest:
        continue
    graph_list_clean.append(entry)
count = len(graph_list_clean)
print(count)

# Attach to every retained graph a mask requesting all of its designable
# positions, together with the element-wise negation of that mask.
for i, graph in enumerate(graph_list_clean):
    n_designable = len(graph.designable_indicies)
    graph.mask = createMask(graph, graph.designable_indicies, n_designable)
    graph.inv_mask = [not flag for flag in graph.mask]

5316


In [4]:
# Load the label -> group lookup and the per-group sizes saved on disk.
# NOTE(review): the "30p" suffix presumably denotes the clustering threshold
# used to build the groups — confirm against whatever produced these files.
group_assignment = torch.load('group_assignment_30p.pt')
group_size = torch.load('group_size_30p.pt')

In [5]:
# Tag each retained graph with the group that its label maps to.
for graph in graph_list_clean:
    graph.group = group_assignment[graph.label]
    # Inverse-group-size weighting, currently disabled:
    #graph.weight = 1.0/group_size[graph.group]

In [6]:
# Spot-check one randomly drawn graph. The draw is from the unfiltered
# `graph_list`, and it advances the global `random` state.
print(random.choice(graph_list))

Data(x=[74, 133], edge_index=[2, 1827], edge_attr=[1827, 114], pos=[73, 3], y=[74, 20], label='3l4w', designable_indicies=[22], mask=[74], inv_mask=[74], group='5nn6')


In [7]:
class Graph_Attn(MessagePassing):
    """Multi-head attention message-passing layer over node and edge features.

    For every edge a query is projected from ``x_i``, while a key and a value
    are projected from ``x_j`` concatenated with the edge attributes. The
    per-head scores are normalised with ``softmax`` grouped by ``index``,
    passed through dropout and used to weight the values. The weighted values
    are sum-aggregated, the heads are flattened, and a final linear layer maps
    them to ``out_channels``.

    Args:
        in_channels: Width of the node feature vectors.
        edge_attr_dim: Width of the edge attribute vectors.
        hidden_channels: Width of each attention head.
        out_channels: Width of the returned node features.
        nheads: Number of attention heads.
        attn_dropout: Dropout probability applied to the attention weights.
        mlp_dropout: Accepted for signature compatibility; not used here.
        neg_slope: Negative slope of the ``relu`` attribute (a LeakyReLU that is
            constructed but not applied in ``forward`` or ``message``).
    """

    def __init__(self, in_channels, edge_attr_dim, hidden_channels, out_channels, nheads = 8, attn_dropout = 0.5, mlp_dropout = 0.0, neg_slope = 0.2):
        # Sum aggregation; nodes are indexed along dimension 0.
        super().__init__(node_dim=0, aggr='add')

        # Kept as a float: it only feeds the score scaling factor in `message`.
        self.in_channels = float(in_channels)
        self.neg_slope = neg_slope
        self.attn_dropout = attn_dropout
        self.relu = nn.LeakyReLU(negative_slope=neg_slope)

        self.nheads = nheads
        self.c = hidden_channels

        heads_width = nheads * hidden_channels
        pair_width = in_channels + edge_attr_dim
        self.Wq = Linear(in_channels, heads_width)   # query projection
        self.Wz = Linear(pair_width, heads_width)    # key projection
        self.Wv = Linear(pair_width, heads_width)    # value projection
        self.W0 = Linear(heads_width, out_channels)  # output projection

        self.dropout = nn.Dropout(p=self.attn_dropout)

    def forward(self, x, edge_index, edge_attr)-> Tensor:
        """Aggregate the attention messages and project the flattened heads."""
        aggregated = self.propagate(edge_index, x=x, edge_attr=edge_attr)
        flat_heads = aggregated.view(-1, self.nheads * self.c)
        return self.W0(flat_heads)

    def message(self, x_j: Tensor, x_i: Tensor, edge_attr: Tensor, index)-> Tensor:
        """Return the attention-weighted value tensor for every edge."""
        head_shape = (-1, self.nheads, self.c)
        neighbour_and_edge = torch.cat([x_j, edge_attr], dim=-1)

        query = self.Wq(x_i).view(*head_shape)
        key = self.Wz(neighbour_and_edge).view(*head_shape)
        value = self.Wv(neighbour_and_edge).view(*head_shape)

        # Scores are scaled by 1/sqrt(in_channels), i.e. the input feature
        # width rather than the per-head width.
        scale = (1.0 / self.in_channels) ** 0.5
        scores = (query * key * scale).sum(dim=-1)

        attention = self.dropout(softmax(scores, index))
        return value * attention.unsqueeze(-1)
    
# Smoke test: run the layer once on a random 14-node graph with 7 node
# features and 5 edge features. The names bound here (`graph`, `model`, ...)
# are notebook globals and are reassigned by later cells.
x = torch.rand(14,7)
edge_index  = erdos_renyi_graph(14, 0.5)
edge_attr = torch.rand(edge_index.size()[1], 5)

graph = Data(x = x, edge_index = edge_index, edge_attr = edge_attr)
# Positional arguments: in_channels=7, edge_attr_dim=5, hidden_channels=10, out_channels=7.
model = Graph_Attn(7,5,10,7)
out = model(graph.x,graph.edge_index,graph.edge_attr)
#print(out.size())

In [8]:
class MLP_Conv_nodes(MessagePassing):
    """Sum-aggregated message passing whose messages are produced by an MLP.

    The message for an edge is the MLP applied to the concatenation
    ``[x_j, x_i, edge_attr]``.

    Args:
        in_channels: Width of the node feature vectors.
        edge_attr_dim: Width of the edge attribute vectors.
        hidden_channels: Hidden width of the MLP.
        out_channels: Width of each message (and of the aggregated output).
    """

    def __init__(self, in_channels, edge_attr_dim, hidden_channels, out_channels):
        super().__init__(aggr='add')
        # The MLP consumes both endpoint feature vectors plus the edge attributes.
        # (A hand-written nn.Sequential stack was used here previously.)
        message_width = 2 * in_channels + edge_attr_dim
        self.mlp = MLP(
            in_channels=message_width,
            hidden_channels=hidden_channels,
            out_channels=out_channels,
            num_layers=2,
            norm='layer',
            dropout=0.1,
            act='gelu',
        )

    def forward(self, x, edge_index, edge_attr)-> Tensor:
        """Run one round of message passing and return the aggregated result."""
        return self.propagate(edge_index, x=x, edge_attr=edge_attr)

    def message(self, x_j: Tensor, x_i: Tensor, edge_attr: Tensor)-> Tensor:
        """Return the MLP output for every edge."""
        stacked = torch.cat([x_j, x_i, edge_attr], dim=-1)
        return self.mlp(stacked)

In [9]:
class MLP_Conv_edges(MessagePassing):
    """Edge-update layer: computes a new vector per edge with an MLP.

    ``forward`` goes through ``edge_updater`` and returns its result, so the
    output has one row per edge. Each row is the MLP applied to the
    concatenation ``[x_j, x_i, edge_attr]``.

    Args:
        in_channels: Width of the node feature vectors.
        edge_attr_dim: Width of the incoming edge attribute vectors.
        hidden_channels: Hidden width of the MLP.
        out_channels: Width of the produced edge vectors.
    """

    def __init__(self, in_channels, edge_attr_dim, hidden_channels, out_channels):
        super().__init__(aggr='add')
        # The MLP consumes both endpoint feature vectors plus the edge attributes.
        # (A hand-written nn.Sequential stack was used here previously.)
        edge_input_width = 2 * in_channels + edge_attr_dim
        self.mlp = MLP(
            in_channels=edge_input_width,
            hidden_channels=hidden_channels,
            out_channels=out_channels,
            num_layers=2,
            norm='layer',
            dropout=0.1,
            act='gelu',
        )

    def forward(self, x, edge_index, edge_attr)-> Tensor:
        """Return the updated per-edge vectors."""
        return self.edge_updater(edge_index, x=x, edge_attr=edge_attr)

    def edge_update(self, x_j: Tensor, x_i: Tensor, edge_attr: Tensor)-> Tensor:
        """Return the MLP output for every edge."""
        stacked = torch.cat([x_j, x_i, edge_attr], dim=-1)
        return self.mlp(stacked)

In [10]:
class Encoder_Layer(torch.nn.Module):
    """One encoder block that refreshes node features and then edge features.

    Node path: attention message passing, then a feed-forward MLP; each step is
    followed by dropout, a residual connection and layer normalisation.
    Edge path: an MLP edge update computed from the refreshed node features,
    again followed by dropout, a residual connection and layer normalisation.

    Args:
        node_size: Width of the node features (kept unchanged by the block).
        edge_size: Width of the edge features (kept unchanged by the block).
        hidden_size: Per-head width of the attention layer; the edge MLP uses
            four times this value as its hidden width.
    """

    def __init__(self, node_size, edge_size, hidden_size):
        super().__init__()
        self.dropout = nn.Dropout(p=0.1)
        self.msg1 = Graph_Attn(
            node_size, edge_size, hidden_size, node_size,
            nheads=3,
        )
        self.norm_node1 = LayerNorm(node_size, mode='node')
        self.norm_node2 = LayerNorm(node_size, mode='node')
        self.norm_edge1 = LayerNorm(edge_size, mode='node')
        self.feed_forward = MLP(
            in_channels=node_size,
            hidden_channels=4 * node_size,
            out_channels=node_size,
            num_layers=2,
            norm='layer',
            dropout=0.1,
            act='gelu',
        )
        self.edge_message = MLP_Conv_edges(
            node_size, edge_size, int(4 * hidden_size), edge_size,
        )

    def forward(self, x, edge_index, edge_attr):
        """Return the tuple ``(updated node features, updated edge features)``."""
        # Attention sub-layer with residual connection.
        attended = self.msg1(x, edge_index, edge_attr)
        nodes = self.norm_node1(x + self.dropout(attended))

        # Feed-forward sub-layer with residual connection.
        transformed = self.feed_forward(nodes)
        nodes = self.norm_node2(nodes + self.dropout(transformed))

        # Edge update driven by the refreshed node features.
        edge_delta = self.edge_message(nodes, edge_index, edge_attr)
        edges = self.norm_edge1(edge_attr + self.dropout(edge_delta))
        return nodes, edges

In [11]:
class AA_Classifier(torch.nn.Module):
    """Per-node 20-way classifier built from four stacked encoder layers.

    ``forward`` reads ``x``, ``edge_index`` and ``edge_attr`` from the graph,
    threads node and edge features through ``conv1`` .. ``conv4`` and applies a
    linear head, returning one 20-wide row of raw scores per node.

    The ``Droput``, ``softmax`` and ``tanh`` attributes are constructed but not
    applied in ``forward``.
    """

    def __init__(self):
        super().__init__()
        self.node_feature_size = 133
        self.node_feature_hidden_size = 128
        self.node_feature_size_out = 133
        self.edge_dim = 114
        self.dropout = 0.1
        self.Droput = nn.Dropout(p=self.dropout)

        # Linear read-out from node features to the 20 output classes.
        self.ff_out = Linear(self.node_feature_size, 20)

        # All encoder layers share the same dimensions.
        encoder_dims = (
            self.node_feature_size,
            self.edge_dim,
            self.node_feature_hidden_size,
        )
        self.conv1 = Encoder_Layer(*encoder_dims)
        self.conv2 = Encoder_Layer(*encoder_dims)
        self.conv3 = Encoder_Layer(*encoder_dims)
        self.conv4 = Encoder_Layer(*encoder_dims)

        self.softmax = torch.nn.Softmax(dim=1)
        self.tanh = nn.Tanh()

    def forward(self,graph):
        """Return the raw class scores for every node of ``graph``."""
        node_state, edge_state = graph.x, graph.edge_attr
        for layer in (self.conv1, self.conv2, self.conv3, self.conv4):
            node_state, edge_state = layer(node_state, graph.edge_index, edge_state)
        return self.ff_out(node_state)

In [12]:
#graph_list_clean = torch.load('full_graphs_mn_rm_6_11302023.pt')

In [13]:
train_name = torch.load('train_name.pt')
test_name = torch.load('test_name.pt')

# Hash-based lookups: testing `label in <sequence>` for every graph would
# rescan the whole name collection each time.
train_label_set = set(train_name)
test_label_set = set(test_name)

train_data = []
val_data = []

# Route each graph by its label. A label found in both collections goes to the
# training split; a graph whose label is in neither is left out of both splits.
for graph in graph_list_clean:
    if graph.label in train_label_set:
        train_data.append(graph)
    elif graph.label in test_label_set:
        val_data.append(graph)

In [14]:
# Report how many graphs ended up in each split.
print("Train:", len(train_data))
print("Val:", len(val_data))

Train: 3986
Val: 1330


In [15]:
#idxs = torch.load('idxs_11302023.pt')
#train_idxs = idxs[0].detach().numpy()
#val_idxs = idxs[1].detach().numpy()
#train_data = [graph_list_clean[int(i)] for i in train_idxs]
#val_data = [graph_list_clean[int(i)] for i in val_idxs]
#torch.save(train_data, 'train_data_12152023')
#torch.save(val_data, 'val_data_12152023')

In [16]:
#torch.save(train_data, 'train_data_12202023')
#torch.save(val_data, 'val_data_12202023')

In [17]:
from torch_geometric.loader import DataLoader
# Number of graphs per training batch.
train_batch_size = 8
# `full_dl` iterates the unfiltered `graph_list`; the other two loaders wrap
# the train/validation lists. Validation is served one graph at a time.
full_dl = DataLoader(graph_list,batch_size = 1, shuffle = True)
train_dl = DataLoader(train_data,batch_size = train_batch_size, shuffle = True)
val_dl = DataLoader(val_data,batch_size = 1, shuffle = True)

In [18]:
# NOTE(review): pickle.load can execute arbitrary code while deserialising;
# only open this file if it comes from a trusted source.
with open('AA_embeddings.pkl', 'rb') as f:
    AA_embeddings = pickle.load(f)

# Three-letter residue codes; the training loop draws from this list to look
# up replacement rows in `AA_embeddings`.
AA_3_letters = ['ALA','ARG','ASN','ASP','CYS','GLN','GLU','GLY','HIS','ILE','LEU','LYS','MET','PHE','PRO','SER','THR','TRP','TYR','VAL']

In [19]:
# Mean-reduced, label-smoothed cross-entropy intended for use with `custom_loss`.
lossf = torch.nn.CrossEntropyLoss(label_smoothing=0.1, reduction='mean')


def custom_loss(predict, truth, lossf):
    """Add up the loss of two prediction/target pairs.

    Args:
        predict: Indexable holding two predictions (positions 0 and 1).
        truth: Indexable holding the two matching targets.
        lossf: Callable ``lossf(prediction, target)`` that returns a loss.

    Returns:
        ``lossf(predict[0], truth[0]) + lossf(predict[1], truth[1])``.
    """
    per_pair = [lossf(predict[k], truth[k]) for k in range(2)]
    return per_pair[0] + per_pair[1]
    

In [20]:
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = AA_Classifier()
# Uncomment to warm-start from a saved checkpoint:
#model.load_state_dict(torch.load('9_encder_layers_12092023.pt'))
model.to(DEVICE) # put on GPU

# Label-smoothed cross-entropy, summed (not averaged) over the selected nodes,
# so every loss value below is a per-graph sum rather than a per-node mean.
criterion = torch.nn.CrossEntropyLoss(reduction = 'sum', label_smoothing = 0.1)
optimizer = torch.optim.Adam(model.parameters(), betas = (0.9,0.98),eps =1e-9, lr=1e-4, weight_decay=1e-4)

# Training loop
num_epochs = 1000  # Adjust the number of epochs as needed
losses = []        # one [train, val] pair of per-graph average losses per epoch
# Weights are only written once the average validation loss drops below this
# value; from then on it tracks the best validation loss seen so far.
smallest = 39.0

for epoch in range(num_epochs):
    total_loss = 0.0
    val_loss = 0.0

    # ---- training pass ----
    model.train()
    for batch in train_dl:
        inputs = batch.to(DEVICE)

        # `batch.mask` holds one mask list per graph; concatenate them into a
        # single flat list used to index the batched node tensors.
        mask = []
        for graph_mask in batch.mask:
            mask += graph_mask

        # Zero the parameter gradients
        optimizer.zero_grad()

        # Corrupt the input rows of the masked nodes: 80% are zeroed, 12% are
        # replaced by the embedding of a randomly chosen residue, and the
        # remaining 8% are left untouched.
        with torch.no_grad():
            for i, is_masked in enumerate(mask):
                if is_masked:
                    p = random.random()
                    if p < 0.8:
                        # Scalar fill works for any feature width and device.
                        inputs.x[i] = 0.0
                    elif p < 0.92:
                        inputs.x[i] = AA_embeddings[random.choice(AA_3_letters)]

        # Forward pass; the loss is taken over the masked nodes only.
        outputs = model(inputs)
        loss = criterion(outputs[mask], inputs.y[mask])

        # Backpropagation and optimization
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    # ---- validation pass ----
    model.eval()
    with torch.no_grad():
        for batch in val_dl:
            # Only the first graph of each batch is evaluated (the loader
            # serves one graph per batch).
            inputs = batch[0].to(DEVICE)

            # Validation hides every masked node by zeroing its input row.
            for i, is_masked in enumerate(inputs.mask):
                if is_masked:
                    inputs.x[i] = 0.0

            outputs = model(inputs)
            loss = criterion(outputs[inputs.mask], inputs.y[inputs.mask])
            val_loss += loss.item()

    # Average per graph. The training denominator is the real number of
    # training graphs: batch_size * number_of_batches over-counts whenever the
    # final batch is only partially filled.
    avg_loss = total_loss / len(train_dl.dataset)
    avg_val_loss = val_loss / len(val_dl)
    print(f'Epoch [{epoch+1}/{num_epochs}] Loss: {avg_loss:.4f} Val Loss: {avg_val_loss:.4f}')
    losses.append([avg_loss,avg_val_loss])

    # Checkpoint whenever the validation loss improves on the best so far.
    if avg_val_loss < smallest:
        torch.save(model.state_dict(), 'model_weights.pt')
        smallest = avg_val_loss

print('Training complete')


Epoch [1/1000] Loss: 63.0116 Val Loss: 61.4250
Epoch [2/1000] Loss: 59.1496 Val Loss: 58.6495
Epoch [3/1000] Loss: 56.3227 Val Loss: 56.2046
Epoch [4/1000] Loss: 54.3516 Val Loss: 54.8734
Epoch [5/1000] Loss: 52.9044 Val Loss: 53.2981
Epoch [6/1000] Loss: 51.5698 Val Loss: 52.0437
Epoch [7/1000] Loss: 50.5457 Val Loss: 51.3342
Epoch [8/1000] Loss: 49.4939 Val Loss: 50.1952
Epoch [9/1000] Loss: 48.6505 Val Loss: 49.5509
Epoch [10/1000] Loss: 47.7795 Val Loss: 48.5526
Epoch [11/1000] Loss: 46.9801 Val Loss: 47.7358
Epoch [12/1000] Loss: 46.2745 Val Loss: 47.0487
Epoch [13/1000] Loss: 45.6871 Val Loss: 46.3933
Epoch [14/1000] Loss: 44.9222 Val Loss: 45.6543
Epoch [15/1000] Loss: 44.3258 Val Loss: 45.2036
Epoch [16/1000] Loss: 43.7518 Val Loss: 44.6805
Epoch [17/1000] Loss: 43.1555 Val Loss: 44.1447
Epoch [18/1000] Loss: 42.6600 Val Loss: 43.5076
Epoch [19/1000] Loss: 42.1390 Val Loss: 43.0104
Epoch [20/1000] Loss: 41.7516 Val Loss: 42.7103
Epoch [21/1000] Loss: 41.1286 Val Loss: 42.2626
E

Epoch [171/1000] Loss: 23.1099 Val Loss: 30.2915
Epoch [172/1000] Loss: 23.1676 Val Loss: 30.1474
Epoch [173/1000] Loss: 23.0170 Val Loss: 30.1410
Epoch [174/1000] Loss: 22.9844 Val Loss: 30.1811
Epoch [175/1000] Loss: 23.0127 Val Loss: 30.2083
Epoch [176/1000] Loss: 22.8763 Val Loss: 30.0818
Epoch [177/1000] Loss: 22.8616 Val Loss: 30.2758
Epoch [178/1000] Loss: 22.9524 Val Loss: 30.0438
Epoch [179/1000] Loss: 22.9220 Val Loss: 30.1708
Epoch [180/1000] Loss: 22.7496 Val Loss: 30.0705
Epoch [181/1000] Loss: 22.7447 Val Loss: 30.1613
Epoch [182/1000] Loss: 22.8077 Val Loss: 30.1353
Epoch [183/1000] Loss: 22.6641 Val Loss: 30.1368
Epoch [184/1000] Loss: 22.7240 Val Loss: 30.1154
Epoch [185/1000] Loss: 22.6533 Val Loss: 29.9972
Epoch [186/1000] Loss: 22.6394 Val Loss: 29.9983
Epoch [187/1000] Loss: 22.5844 Val Loss: 30.0182
Epoch [188/1000] Loss: 22.5488 Val Loss: 30.0142
Epoch [189/1000] Loss: 22.4246 Val Loss: 30.0838
Epoch [190/1000] Loss: 22.4040 Val Loss: 30.0600
Epoch [191/1000] Los

Epoch [339/1000] Loss: 19.6607 Val Loss: 29.5941
Epoch [340/1000] Loss: 19.6685 Val Loss: 29.6420
Epoch [341/1000] Loss: 19.7432 Val Loss: 29.5717
Epoch [342/1000] Loss: 19.6833 Val Loss: 29.7362
Epoch [343/1000] Loss: 19.7124 Val Loss: 29.6018
Epoch [344/1000] Loss: 19.6513 Val Loss: 29.5897
Epoch [345/1000] Loss: 19.6165 Val Loss: 29.7433
Epoch [346/1000] Loss: 19.5965 Val Loss: 29.6349
Epoch [347/1000] Loss: 19.5647 Val Loss: 29.6732
Epoch [348/1000] Loss: 19.6521 Val Loss: 29.6217
Epoch [349/1000] Loss: 19.5466 Val Loss: 29.6944
Epoch [350/1000] Loss: 19.5649 Val Loss: 29.6070
Epoch [351/1000] Loss: 19.5210 Val Loss: 29.7090
Epoch [352/1000] Loss: 19.5551 Val Loss: 29.6323
Epoch [353/1000] Loss: 19.5076 Val Loss: 29.5540
Epoch [354/1000] Loss: 19.6072 Val Loss: 29.6480
Epoch [355/1000] Loss: 19.5388 Val Loss: 29.7681
Epoch [356/1000] Loss: 19.5834 Val Loss: 29.6729
Epoch [357/1000] Loss: 19.4584 Val Loss: 29.6726
Epoch [358/1000] Loss: 19.4399 Val Loss: 29.6689
Epoch [359/1000] Los

Epoch [507/1000] Loss: 18.4067 Val Loss: 29.4010
Epoch [508/1000] Loss: 18.4109 Val Loss: 29.4453
Epoch [509/1000] Loss: 18.4371 Val Loss: 29.4811
Epoch [510/1000] Loss: 18.3991 Val Loss: 29.3622
Epoch [511/1000] Loss: 18.3871 Val Loss: 29.3036
Epoch [512/1000] Loss: 18.3581 Val Loss: 29.3967
Epoch [513/1000] Loss: 18.3972 Val Loss: 29.5496
Epoch [514/1000] Loss: 18.3879 Val Loss: 29.3920
Epoch [515/1000] Loss: 18.3300 Val Loss: 29.4629
Epoch [516/1000] Loss: 18.3854 Val Loss: 29.5667
Epoch [517/1000] Loss: 18.3856 Val Loss: 29.4216
Epoch [518/1000] Loss: 18.3914 Val Loss: 29.3803
Epoch [519/1000] Loss: 18.4605 Val Loss: 29.4037
Epoch [520/1000] Loss: 18.4186 Val Loss: 29.4472
Epoch [521/1000] Loss: 18.3778 Val Loss: 29.4698
Epoch [522/1000] Loss: 18.3312 Val Loss: 29.4138
Epoch [523/1000] Loss: 18.3399 Val Loss: 29.4556
Epoch [524/1000] Loss: 18.3370 Val Loss: 29.5329
Epoch [525/1000] Loss: 18.2767 Val Loss: 29.5373
Epoch [526/1000] Loss: 18.3478 Val Loss: 29.3744
Epoch [527/1000] Los

Epoch [675/1000] Loss: 17.7576 Val Loss: 29.5364
Epoch [676/1000] Loss: 17.7675 Val Loss: 29.4274
Epoch [677/1000] Loss: 17.7374 Val Loss: 29.3421
Epoch [678/1000] Loss: 17.7956 Val Loss: 29.4276
Epoch [679/1000] Loss: 17.7528 Val Loss: 29.4312
Epoch [680/1000] Loss: 17.7873 Val Loss: 29.3346
Epoch [681/1000] Loss: 17.7127 Val Loss: 29.4427
Epoch [682/1000] Loss: 17.7892 Val Loss: 29.3953
Epoch [683/1000] Loss: 17.7585 Val Loss: 29.4313
Epoch [684/1000] Loss: 17.6908 Val Loss: 29.4820
Epoch [685/1000] Loss: 17.7552 Val Loss: 29.3658
Epoch [686/1000] Loss: 17.7107 Val Loss: 29.4479
Epoch [687/1000] Loss: 17.7568 Val Loss: 29.4288
Epoch [688/1000] Loss: 17.7008 Val Loss: 29.4227
Epoch [689/1000] Loss: 17.7192 Val Loss: 29.3288
Epoch [690/1000] Loss: 17.6952 Val Loss: 29.4615
Epoch [691/1000] Loss: 17.7786 Val Loss: 29.3310
Epoch [692/1000] Loss: 17.7421 Val Loss: 29.3402
Epoch [693/1000] Loss: 17.7720 Val Loss: 29.3282
Epoch [694/1000] Loss: 17.7229 Val Loss: 29.3381
Epoch [695/1000] Los

Epoch [843/1000] Loss: 17.3060 Val Loss: 29.4692
Epoch [844/1000] Loss: 17.3480 Val Loss: 29.4170
Epoch [845/1000] Loss: 17.3086 Val Loss: 29.3855
Epoch [846/1000] Loss: 17.3699 Val Loss: 29.4689
Epoch [847/1000] Loss: 17.3810 Val Loss: 29.3899
Epoch [848/1000] Loss: 17.3443 Val Loss: 29.5274
Epoch [849/1000] Loss: 17.3356 Val Loss: 29.4645
Epoch [850/1000] Loss: 17.3496 Val Loss: 29.3534
Epoch [851/1000] Loss: 17.2887 Val Loss: 29.4570
Epoch [852/1000] Loss: 17.3164 Val Loss: 29.3949
Epoch [853/1000] Loss: 17.2886 Val Loss: 29.4845
Epoch [854/1000] Loss: 17.3042 Val Loss: 29.3013
Epoch [855/1000] Loss: 17.2705 Val Loss: 29.4354
Epoch [856/1000] Loss: 17.3375 Val Loss: 29.4602
Epoch [857/1000] Loss: 17.3679 Val Loss: 29.4262
Epoch [858/1000] Loss: 17.3237 Val Loss: 29.4511
Epoch [859/1000] Loss: 17.3199 Val Loss: 29.4846
Epoch [860/1000] Loss: 17.3537 Val Loss: 29.4953
Epoch [861/1000] Loss: 17.2827 Val Loss: 29.4097
Epoch [862/1000] Loss: 17.2996 Val Loss: 29.4428
Epoch [863/1000] Los

In [21]:
# Persist the per-epoch [train, val] loss pairs as a tensor.
torch.save(torch.Tensor(losses),'loss.pt')