In [1]:
!pip install torch-scatter torch-sparse torch-cluster torch-spline-conv torch-geometric -f https://data.pyg.org/whl/torch-1.12.0+cu116.html

import torch
import torch.nn.functional as F
import tqdm

from torch_geometric.data import Batch
from torch_geometric.loader import DataLoader, LinkNeighborLoader
from torch_geometric.nn import GraphSAGE
from torch_geometric.datasets import TUDataset
import torch_geometric.transforms as T

dataset = TUDataset(root='/tmp/NCI1', name='NCI1', transform=T.NormalizeFeatures())
torch.manual_seed(12315)
dataset = dataset.shuffle()
dataset_length = len(dataset)

# split dataset into 3 parts
DA_train = dataset[0:int(0.4*dataset_length)]
D_aux = dataset[int(0.4*dataset_length):int(0.7*dataset_length)]
DA_test = dataset[int(0.7*dataset_length):]

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in links: https://data.pyg.org/whl/torch-1.12.0+cu116.html
Collecting torch-scatter
  Downloading https://data.pyg.org/whl/torch-1.12.0%2Bcu116/torch_scatter-2.0.9-cp37-cp37m-linux_x86_64.whl (8.0 MB)
[K     |████████████████████████████████| 8.0 MB 52.8 MB/s 
[?25hCollecting torch-sparse
  Downloading https://data.pyg.org/whl/torch-1.12.0%2Bcu116/torch_sparse-0.6.15-cp37-cp37m-linux_x86_64.whl (3.5 MB)
[K     |████████████████████████████████| 3.5 MB 54.1 MB/s 
[?25hCollecting torch-cluster
  Downloading https://data.pyg.org/whl/torch-1.12.0%2Bcu116/torch_cluster-1.6.0-cp37-cp37m-linux_x86_64.whl (2.4 MB)
[K     |████████████████████████████████| 2.4 MB 62.3 MB/s 
[?25hCollecting torch-spline-conv
  Downloading https://data.pyg.org/whl/torch-1.12.0%2Bcu116/torch_spline_conv-1.2.1-cp37-cp37m-linux_x86_64.whl (706 kB)
[K     |████████████████████████████████| 706 kB 62.0 MB/

Downloading https://www.chrsmrrs.com/graphkerneldatasets/NCI1.zip
Extracting /tmp/NCI1/NCI1/NCI1.zip
Processing...
Done!


In [2]:
print(len(DA_train))
DA_train_train = DA_train[0:int(0.7*len(DA_train))]
DA_train_val = DA_train[int(0.7*len(DA_train)):]
print(len(DA_train_train))
print(len(DA_train_val))

1644
1150
494


In [3]:
len(D_aux)

1233

In [4]:
len(DA_test)

1233

In [5]:
train_data = Batch.from_data_list(DA_train_train)
test_data = Batch.from_data_list(DA_train_val)

train_loader = LinkNeighborLoader(train_data, batch_size=2048, shuffle=True,
                            neg_sampling_ratio=0.5, num_neighbors=[10,10],
                            num_workers=2, persistent_workers=True)

test_loader = LinkNeighborLoader(test_data, batch_size=2048, shuffle=False,
                            neg_sampling_ratio=0.5, num_neighbors=[10,10],
                            num_workers=2, persistent_workers=True)

In [6]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = GraphSAGE(
            in_channels=DA_train.num_features,
            hidden_channels=192*2,
            num_layers=3,
            out_channels=192
        ).to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=0.005)

In [7]:
def train():
    model.train()
    
    total_loss = total_examples = 0
    
    for data in tqdm.tqdm(train_loader):
        data = data.to(device)
        optimizer.zero_grad()
        
        h = model(data.x, data.edge_index)
        
        h_src = h[data.edge_label_index[0]]
        h_dst = h[data.edge_label_index[1]]
        
        link_pred = (h_src * h_dst).sum(dim=-1)
        
        loss = F.binary_cross_entropy_with_logits(link_pred, data.edge_label)
        
        loss.backward()
        optimizer.step()
        
        total_loss += float(loss) * link_pred.numel()
        total_examples += link_pred.numel()
        
    return total_loss / total_examples

In [8]:
def test():
    model.eval()
    
    total_loss = total_examples = 0
    
    for data in tqdm.tqdm(test_loader):
        data = data.to(device)
        # optimizer.zero_grad()
        
        h = model(data.x, data.edge_index)
        
        h_src = h[data.edge_label_index[0]]
        h_dst = h[data.edge_label_index[1]]
        
        link_pred = (h_src * h_dst).sum(dim=-1)
        
        loss = F.binary_cross_entropy_with_logits(link_pred, data.edge_label)
        
#         loss.backward()
#         optimizer.step()
        
        total_loss += float(loss) * link_pred.numel()
        total_examples += link_pred.numel()
        
    return total_loss / total_examples

In [9]:
for epoch in range(1,100):
    train_loss = train()
    test_loss = test()
    print(f'Epoch: {epoch:02d}, Train loss: {train_loss:.4f}, Test loss: {test_loss:.4f}')

100%|██████████| 37/37 [00:04<00:00,  8.53it/s]
100%|██████████| 16/16 [00:00<00:00, 27.62it/s]


Epoch: 01, Train loss: 0.6650, Test loss: 0.4568


100%|██████████| 37/37 [00:01<00:00, 25.57it/s]
100%|██████████| 16/16 [00:00<00:00, 31.45it/s]


Epoch: 02, Train loss: 0.4460, Test loss: 0.4176


100%|██████████| 37/37 [00:01<00:00, 24.61it/s]
100%|██████████| 16/16 [00:00<00:00, 34.20it/s]


Epoch: 03, Train loss: 0.3971, Test loss: 0.3525


100%|██████████| 37/37 [00:01<00:00, 25.59it/s]
100%|██████████| 16/16 [00:00<00:00, 35.79it/s]


Epoch: 04, Train loss: 0.3759, Test loss: 0.3293


100%|██████████| 37/37 [00:01<00:00, 24.04it/s]
100%|██████████| 16/16 [00:00<00:00, 35.74it/s]


Epoch: 05, Train loss: 0.3595, Test loss: 0.3099


100%|██████████| 37/37 [00:01<00:00, 25.77it/s]
100%|██████████| 16/16 [00:00<00:00, 24.76it/s]


Epoch: 06, Train loss: 0.3475, Test loss: 0.3130


100%|██████████| 37/37 [00:01<00:00, 22.60it/s]
100%|██████████| 16/16 [00:00<00:00, 32.36it/s]


Epoch: 07, Train loss: 0.3495, Test loss: 0.3134


100%|██████████| 37/37 [00:01<00:00, 25.54it/s]
100%|██████████| 16/16 [00:00<00:00, 33.93it/s]


Epoch: 08, Train loss: 0.3396, Test loss: 0.3124


100%|██████████| 37/37 [00:01<00:00, 25.07it/s]
100%|██████████| 16/16 [00:00<00:00, 31.84it/s]


Epoch: 09, Train loss: 0.3360, Test loss: 0.3120


100%|██████████| 37/37 [00:01<00:00, 25.48it/s]
100%|██████████| 16/16 [00:00<00:00, 38.37it/s]


Epoch: 10, Train loss: 0.3365, Test loss: 0.3050


100%|██████████| 37/37 [00:01<00:00, 24.68it/s]
100%|██████████| 16/16 [00:00<00:00, 31.40it/s]


Epoch: 11, Train loss: 0.3342, Test loss: 0.2992


100%|██████████| 37/37 [00:01<00:00, 24.41it/s]
100%|██████████| 16/16 [00:00<00:00, 34.41it/s]


Epoch: 12, Train loss: 0.3352, Test loss: 0.3074


100%|██████████| 37/37 [00:01<00:00, 24.99it/s]
100%|██████████| 16/16 [00:00<00:00, 36.97it/s]


Epoch: 13, Train loss: 0.3320, Test loss: 0.3001


100%|██████████| 37/37 [00:01<00:00, 24.87it/s]
100%|██████████| 16/16 [00:00<00:00, 32.68it/s]


Epoch: 14, Train loss: 0.3306, Test loss: 0.3087


100%|██████████| 37/37 [00:01<00:00, 24.22it/s]
100%|██████████| 16/16 [00:00<00:00, 39.92it/s]


Epoch: 15, Train loss: 0.3275, Test loss: 0.3004


100%|██████████| 37/37 [00:01<00:00, 24.62it/s]
100%|██████████| 16/16 [00:00<00:00, 27.31it/s]


Epoch: 16, Train loss: 0.3289, Test loss: 0.3057


100%|██████████| 37/37 [00:02<00:00, 15.06it/s]
100%|██████████| 16/16 [00:00<00:00, 21.60it/s]


Epoch: 17, Train loss: 0.3282, Test loss: 0.2979


100%|██████████| 37/37 [00:02<00:00, 15.29it/s]
100%|██████████| 16/16 [00:00<00:00, 31.16it/s]


Epoch: 18, Train loss: 0.3234, Test loss: 0.3032


100%|██████████| 37/37 [00:01<00:00, 24.21it/s]
100%|██████████| 16/16 [00:00<00:00, 38.85it/s]


Epoch: 19, Train loss: 0.3231, Test loss: 0.3010


100%|██████████| 37/37 [00:01<00:00, 24.73it/s]
100%|██████████| 16/16 [00:00<00:00, 32.56it/s]


Epoch: 20, Train loss: 0.3209, Test loss: 0.2943


100%|██████████| 37/37 [00:01<00:00, 24.82it/s]
100%|██████████| 16/16 [00:00<00:00, 31.64it/s]


Epoch: 21, Train loss: 0.3191, Test loss: 0.3000


100%|██████████| 37/37 [00:01<00:00, 24.60it/s]
100%|██████████| 16/16 [00:00<00:00, 38.58it/s]


Epoch: 22, Train loss: 0.3172, Test loss: 0.3040


100%|██████████| 37/37 [00:01<00:00, 24.37it/s]
100%|██████████| 16/16 [00:00<00:00, 31.58it/s]


Epoch: 23, Train loss: 0.3188, Test loss: 0.2945


100%|██████████| 37/37 [00:01<00:00, 24.58it/s]
100%|██████████| 16/16 [00:00<00:00, 35.34it/s]


Epoch: 24, Train loss: 0.3147, Test loss: 0.2952


100%|██████████| 37/37 [00:01<00:00, 23.40it/s]
100%|██████████| 16/16 [00:00<00:00, 34.69it/s]


Epoch: 25, Train loss: 0.3124, Test loss: 0.3039


100%|██████████| 37/37 [00:01<00:00, 25.36it/s]
100%|██████████| 16/16 [00:00<00:00, 31.12it/s]


Epoch: 26, Train loss: 0.3154, Test loss: 0.2918


100%|██████████| 37/37 [00:01<00:00, 23.84it/s]
100%|██████████| 16/16 [00:00<00:00, 32.05it/s]


Epoch: 27, Train loss: 0.3153, Test loss: 0.3013


100%|██████████| 37/37 [00:01<00:00, 25.59it/s]
100%|██████████| 16/16 [00:00<00:00, 35.96it/s]


Epoch: 28, Train loss: 0.3110, Test loss: 0.2980


100%|██████████| 37/37 [00:01<00:00, 23.06it/s]
100%|██████████| 16/16 [00:00<00:00, 34.40it/s]


Epoch: 29, Train loss: 0.3138, Test loss: 0.2909


100%|██████████| 37/37 [00:01<00:00, 24.17it/s]
100%|██████████| 16/16 [00:00<00:00, 31.90it/s]


Epoch: 30, Train loss: 0.3134, Test loss: 0.2883


100%|██████████| 37/37 [00:01<00:00, 23.98it/s]
100%|██████████| 16/16 [00:00<00:00, 34.79it/s]


Epoch: 31, Train loss: 0.3097, Test loss: 0.2886


100%|██████████| 37/37 [00:01<00:00, 25.59it/s]
100%|██████████| 16/16 [00:00<00:00, 33.55it/s]


Epoch: 32, Train loss: 0.3114, Test loss: 0.2881


100%|██████████| 37/37 [00:01<00:00, 23.81it/s]
100%|██████████| 16/16 [00:00<00:00, 33.63it/s]


Epoch: 33, Train loss: 0.3085, Test loss: 0.2974


100%|██████████| 37/37 [00:01<00:00, 25.11it/s]
100%|██████████| 16/16 [00:00<00:00, 31.64it/s]


Epoch: 34, Train loss: 0.3120, Test loss: 0.2928


100%|██████████| 37/37 [00:01<00:00, 24.16it/s]
100%|██████████| 16/16 [00:00<00:00, 34.29it/s]


Epoch: 35, Train loss: 0.3077, Test loss: 0.2937


100%|██████████| 37/37 [00:01<00:00, 25.02it/s]
100%|██████████| 16/16 [00:00<00:00, 31.74it/s]


Epoch: 36, Train loss: 0.3070, Test loss: 0.2893


100%|██████████| 37/37 [00:01<00:00, 24.04it/s]
100%|██████████| 16/16 [00:00<00:00, 40.63it/s]


Epoch: 37, Train loss: 0.3066, Test loss: 0.2873


100%|██████████| 37/37 [00:01<00:00, 25.19it/s]
100%|██████████| 16/16 [00:00<00:00, 32.21it/s]


Epoch: 38, Train loss: 0.3063, Test loss: 0.2969


100%|██████████| 37/37 [00:01<00:00, 23.54it/s]
100%|██████████| 16/16 [00:00<00:00, 35.23it/s]


Epoch: 39, Train loss: 0.3075, Test loss: 0.2901


100%|██████████| 37/37 [00:01<00:00, 25.09it/s]
100%|██████████| 16/16 [00:00<00:00, 36.07it/s]


Epoch: 40, Train loss: 0.3061, Test loss: 0.2972


100%|██████████| 37/37 [00:01<00:00, 24.02it/s]
100%|██████████| 16/16 [00:00<00:00, 32.41it/s]


Epoch: 41, Train loss: 0.3073, Test loss: 0.2904


100%|██████████| 37/37 [00:01<00:00, 25.98it/s]
100%|██████████| 16/16 [00:00<00:00, 38.99it/s]


Epoch: 42, Train loss: 0.3046, Test loss: 0.2854


100%|██████████| 37/37 [00:01<00:00, 24.36it/s]
100%|██████████| 16/16 [00:00<00:00, 31.70it/s]


Epoch: 43, Train loss: 0.3057, Test loss: 0.2935


100%|██████████| 37/37 [00:01<00:00, 23.84it/s]
100%|██████████| 16/16 [00:00<00:00, 20.09it/s]


Epoch: 44, Train loss: 0.3024, Test loss: 0.2928


100%|██████████| 37/37 [00:02<00:00, 16.55it/s]
100%|██████████| 16/16 [00:00<00:00, 23.27it/s]


Epoch: 45, Train loss: 0.3038, Test loss: 0.2885


100%|██████████| 37/37 [00:01<00:00, 24.23it/s]
100%|██████████| 16/16 [00:00<00:00, 35.48it/s]


Epoch: 46, Train loss: 0.3030, Test loss: 0.2906


100%|██████████| 37/37 [00:01<00:00, 24.29it/s]
100%|██████████| 16/16 [00:00<00:00, 36.09it/s]


Epoch: 47, Train loss: 0.3044, Test loss: 0.2857


100%|██████████| 37/37 [00:01<00:00, 24.71it/s]
100%|██████████| 16/16 [00:00<00:00, 28.60it/s]


Epoch: 48, Train loss: 0.3059, Test loss: 0.2904


100%|██████████| 37/37 [00:01<00:00, 24.44it/s]
100%|██████████| 16/16 [00:00<00:00, 34.50it/s]


Epoch: 49, Train loss: 0.3014, Test loss: 0.2826


100%|██████████| 37/37 [00:01<00:00, 24.87it/s]
100%|██████████| 16/16 [00:00<00:00, 34.32it/s]


Epoch: 50, Train loss: 0.3020, Test loss: 0.2825


100%|██████████| 37/37 [00:01<00:00, 24.73it/s]
100%|██████████| 16/16 [00:00<00:00, 33.04it/s]


Epoch: 51, Train loss: 0.3030, Test loss: 0.2767


100%|██████████| 37/37 [00:01<00:00, 24.54it/s]
100%|██████████| 16/16 [00:00<00:00, 35.72it/s]


Epoch: 52, Train loss: 0.3028, Test loss: 0.2906


100%|██████████| 37/37 [00:01<00:00, 24.75it/s]
100%|██████████| 16/16 [00:00<00:00, 32.58it/s]


Epoch: 53, Train loss: 0.3027, Test loss: 0.2913


100%|██████████| 37/37 [00:01<00:00, 24.33it/s]
100%|██████████| 16/16 [00:00<00:00, 34.18it/s]


Epoch: 54, Train loss: 0.3004, Test loss: 0.2973


100%|██████████| 37/37 [00:01<00:00, 24.72it/s]
100%|██████████| 16/16 [00:00<00:00, 35.99it/s]


Epoch: 55, Train loss: 0.2971, Test loss: 0.2878


100%|██████████| 37/37 [00:01<00:00, 24.45it/s]
100%|██████████| 16/16 [00:00<00:00, 31.70it/s]


Epoch: 56, Train loss: 0.3029, Test loss: 0.2856


100%|██████████| 37/37 [00:01<00:00, 24.43it/s]
100%|██████████| 16/16 [00:00<00:00, 34.98it/s]


Epoch: 57, Train loss: 0.3013, Test loss: 0.2808


100%|██████████| 37/37 [00:01<00:00, 24.43it/s]
100%|██████████| 16/16 [00:00<00:00, 35.84it/s]


Epoch: 58, Train loss: 0.3026, Test loss: 0.2910


100%|██████████| 37/37 [00:01<00:00, 24.17it/s]
100%|██████████| 16/16 [00:00<00:00, 31.17it/s]


Epoch: 59, Train loss: 0.3019, Test loss: 0.2864


100%|██████████| 37/37 [00:01<00:00, 25.21it/s]
100%|██████████| 16/16 [00:00<00:00, 40.65it/s]


Epoch: 60, Train loss: 0.3004, Test loss: 0.2856


100%|██████████| 37/37 [00:01<00:00, 23.87it/s]
100%|██████████| 16/16 [00:00<00:00, 32.67it/s]


Epoch: 61, Train loss: 0.3008, Test loss: 0.2870


100%|██████████| 37/37 [00:01<00:00, 25.19it/s]
100%|██████████| 16/16 [00:00<00:00, 34.83it/s]


Epoch: 62, Train loss: 0.3003, Test loss: 0.2828


100%|██████████| 37/37 [00:01<00:00, 23.75it/s]
100%|██████████| 16/16 [00:00<00:00, 30.89it/s]


Epoch: 63, Train loss: 0.2989, Test loss: 0.2887


100%|██████████| 37/37 [00:01<00:00, 25.22it/s]
100%|██████████| 16/16 [00:00<00:00, 34.65it/s]


Epoch: 64, Train loss: 0.3002, Test loss: 0.2900


100%|██████████| 37/37 [00:01<00:00, 23.78it/s]
100%|██████████| 16/16 [00:00<00:00, 34.86it/s]


Epoch: 65, Train loss: 0.2987, Test loss: 0.2830


100%|██████████| 37/37 [00:01<00:00, 26.04it/s]
100%|██████████| 16/16 [00:00<00:00, 30.52it/s]


Epoch: 66, Train loss: 0.2990, Test loss: 0.2777


100%|██████████| 37/37 [00:01<00:00, 23.48it/s]
100%|██████████| 16/16 [00:00<00:00, 39.98it/s]


Epoch: 67, Train loss: 0.2969, Test loss: 0.2844


100%|██████████| 37/37 [00:01<00:00, 24.87it/s]
100%|██████████| 16/16 [00:00<00:00, 32.52it/s]


Epoch: 68, Train loss: 0.2975, Test loss: 0.2822


100%|██████████| 37/37 [00:01<00:00, 24.49it/s]
100%|██████████| 16/16 [00:00<00:00, 38.49it/s]


Epoch: 69, Train loss: 0.2955, Test loss: 0.2809


100%|██████████| 37/37 [00:01<00:00, 25.23it/s]
100%|██████████| 16/16 [00:00<00:00, 33.00it/s]


Epoch: 70, Train loss: 0.2998, Test loss: 0.2780


100%|██████████| 37/37 [00:01<00:00, 24.19it/s]
100%|██████████| 16/16 [00:00<00:00, 31.78it/s]


Epoch: 71, Train loss: 0.2950, Test loss: 0.2851


100%|██████████| 37/37 [00:01<00:00, 24.92it/s]
100%|██████████| 16/16 [00:00<00:00, 34.53it/s]


Epoch: 72, Train loss: 0.2967, Test loss: 0.2872


100%|██████████| 37/37 [00:02<00:00, 15.31it/s]
100%|██████████| 16/16 [00:00<00:00, 21.69it/s]


Epoch: 73, Train loss: 0.2950, Test loss: 0.2835


100%|██████████| 37/37 [00:01<00:00, 23.40it/s]
100%|██████████| 16/16 [00:00<00:00, 31.29it/s]


Epoch: 74, Train loss: 0.2980, Test loss: 0.2843


100%|██████████| 37/37 [00:01<00:00, 23.92it/s]
100%|██████████| 16/16 [00:00<00:00, 34.38it/s]


Epoch: 75, Train loss: 0.2972, Test loss: 0.2803


100%|██████████| 37/37 [00:01<00:00, 24.97it/s]
100%|██████████| 16/16 [00:00<00:00, 34.39it/s]


Epoch: 76, Train loss: 0.2952, Test loss: 0.2761


100%|██████████| 37/37 [00:01<00:00, 24.09it/s]
100%|██████████| 16/16 [00:00<00:00, 34.42it/s]


Epoch: 77, Train loss: 0.2959, Test loss: 0.2800


100%|██████████| 37/37 [00:01<00:00, 24.94it/s]
100%|██████████| 16/16 [00:00<00:00, 32.19it/s]


Epoch: 78, Train loss: 0.2948, Test loss: 0.2794


100%|██████████| 37/37 [00:01<00:00, 24.10it/s]
100%|██████████| 16/16 [00:00<00:00, 31.68it/s]


Epoch: 79, Train loss: 0.2953, Test loss: 0.2769


100%|██████████| 37/37 [00:01<00:00, 25.17it/s]
100%|██████████| 16/16 [00:00<00:00, 34.65it/s]


Epoch: 80, Train loss: 0.2965, Test loss: 0.2771


100%|██████████| 37/37 [00:01<00:00, 23.21it/s]
100%|██████████| 16/16 [00:00<00:00, 31.54it/s]


Epoch: 81, Train loss: 0.2934, Test loss: 0.2854


100%|██████████| 37/37 [00:01<00:00, 25.06it/s]
100%|██████████| 16/16 [00:00<00:00, 38.85it/s]


Epoch: 82, Train loss: 0.2939, Test loss: 0.2832


100%|██████████| 37/37 [00:01<00:00, 23.54it/s]
100%|██████████| 16/16 [00:00<00:00, 31.92it/s]


Epoch: 83, Train loss: 0.2934, Test loss: 0.2855


100%|██████████| 37/37 [00:01<00:00, 25.17it/s]
100%|██████████| 16/16 [00:00<00:00, 35.27it/s]


Epoch: 84, Train loss: 0.2941, Test loss: 0.2891


100%|██████████| 37/37 [00:01<00:00, 23.65it/s]
100%|██████████| 16/16 [00:00<00:00, 35.70it/s]


Epoch: 85, Train loss: 0.2935, Test loss: 0.2818


100%|██████████| 37/37 [00:01<00:00, 25.12it/s]
100%|██████████| 16/16 [00:00<00:00, 31.14it/s]


Epoch: 86, Train loss: 0.2950, Test loss: 0.2831


100%|██████████| 37/37 [00:01<00:00, 23.97it/s]
100%|██████████| 16/16 [00:00<00:00, 39.77it/s]


Epoch: 87, Train loss: 0.2945, Test loss: 0.2842


100%|██████████| 37/37 [00:01<00:00, 25.03it/s]
100%|██████████| 16/16 [00:00<00:00, 32.26it/s]


Epoch: 88, Train loss: 0.2951, Test loss: 0.2828


100%|██████████| 37/37 [00:01<00:00, 23.67it/s]
100%|██████████| 16/16 [00:00<00:00, 35.33it/s]


Epoch: 89, Train loss: 0.2939, Test loss: 0.2795


100%|██████████| 37/37 [00:01<00:00, 24.72it/s]
100%|██████████| 16/16 [00:00<00:00, 31.16it/s]


Epoch: 90, Train loss: 0.2932, Test loss: 0.2837


100%|██████████| 37/37 [00:01<00:00, 23.18it/s]
100%|██████████| 16/16 [00:00<00:00, 39.44it/s]


Epoch: 91, Train loss: 0.2940, Test loss: 0.2822


100%|██████████| 37/37 [00:01<00:00, 24.06it/s]
100%|██████████| 16/16 [00:00<00:00, 30.49it/s]


Epoch: 92, Train loss: 0.2940, Test loss: 0.2826


100%|██████████| 37/37 [00:01<00:00, 23.58it/s]
100%|██████████| 16/16 [00:00<00:00, 30.95it/s]


Epoch: 93, Train loss: 0.2959, Test loss: 0.2774


100%|██████████| 37/37 [00:01<00:00, 24.94it/s]
100%|██████████| 16/16 [00:00<00:00, 38.71it/s]


Epoch: 94, Train loss: 0.2920, Test loss: 0.2761


100%|██████████| 37/37 [00:01<00:00, 22.86it/s]
100%|██████████| 16/16 [00:00<00:00, 31.75it/s]


Epoch: 95, Train loss: 0.2948, Test loss: 0.2761


100%|██████████| 37/37 [00:01<00:00, 24.11it/s]
100%|██████████| 16/16 [00:00<00:00, 38.03it/s]


Epoch: 96, Train loss: 0.2926, Test loss: 0.2850


100%|██████████| 37/37 [00:01<00:00, 22.56it/s]
100%|██████████| 16/16 [00:00<00:00, 31.61it/s]


Epoch: 97, Train loss: 0.2939, Test loss: 0.2865


100%|██████████| 37/37 [00:01<00:00, 24.14it/s]
100%|██████████| 16/16 [00:00<00:00, 31.98it/s]


Epoch: 98, Train loss: 0.2922, Test loss: 0.2771


100%|██████████| 37/37 [00:02<00:00, 17.30it/s]
100%|██████████| 16/16 [00:00<00:00, 18.83it/s]

Epoch: 99, Train loss: 0.2923, Test loss: 0.2803





# save the model to local

In [10]:
data_save_path = "NCI_model.pt"
torch.save(model, data_save_path)