In [1]:
# import torch_geometric
import sys
import pickle
sys.path.insert(0, '../src')

In [2]:
import tqdm
import torch.nn.functional as F
import torchvision.transforms as transforms
import torch.optim as optim
import torchvision
import json
import argparse
# import cv2
import numpy as np
import torch
from torch.autograd import Function
from torchvision import models
import torch.nn as nn
import os
import matplotlib.pyplot as plt
import time
from torchvision.datasets import CocoDetection


from PIL import Image
import pandas as pd

sys.path.insert(0, '../src/data/cocoapi/PythonAPI')
from pycocotools.coco import COCO

In [3]:
# https://github.com/jlevy44/WSI-GTFE/blob/master/notebooks/3_fit_gnn_model.ipynb
import torch, torch.nn as nn
from torch_geometric.nn import GCNConv, GATConv, DeepGraphInfomax, SAGEConv
from torch_geometric.nn import DenseGraphConv
from torch_geometric.utils import to_dense_batch, to_dense_adj, dense_to_sparse
from torch_geometric.nn import GINEConv
from torch_geometric.utils import dropout_adj


class GCNNet(torch.nn.Module):
    def __init__(self, inp_dim, out_dim, hidden_topology=[32,64,128,128], p=0.2, p2=0.0, drop_each=True):
        super(GCNNet, self).__init__()
        self.out_dim=out_dim
        self.convs = nn.ModuleList([GATConv(inp_dim, hidden_topology[0])]+[GATConv(hidden_topology[i],hidden_topology[i+1]) for i in range(len(hidden_topology[:-1]))])
        self.drop_edge = lambda edge_index: dropout_adj(edge_index,p=p2)[0]
        self.dropout = nn.Dropout(p)
        self.fc = nn.Linear(hidden_topology[-1], out_dim)
        self.drop_each=drop_each

    def forward(self, x, edge_index, edge_attr=None, return_attention=False):
        attention_weights=[]
        for conv in self.convs:
            if self.drop_each and self.training: edge_index=self.drop_edge(edge_index)
            x, attention = conv(x, edge_index, edge_attr,return_attention_weights=True)
            x = F.relu(x)
            attention_weights.append(attention)
        if self.training:
            x = self.dropout(x)
        x = self.fc(x)
        if return_attention: return x, attention_weights
        return x

In [5]:
model_gcn.cuda()

GCNNet(
  (convs): ModuleList(
    (0): GATConv(2048, 32, heads=1)
    (1): GATConv(32, 32, heads=1)
  )
  (dropout): Dropout(p=0.2, inplace=False)
  (fc): Linear(in_features=32, out_features=80, bias=True)
)

In [6]:
model_mlp

Sequential(
  (0): Sequential(
    (0): Linear(in_features=2048, out_features=32, bias=True)
    (1): ReLU()
    (2): Dropout(p=0.3, inplace=False)
  )
  (1): Linear(in_features=32, out_features=80, bias=True)
)

In [4]:
# ! jupyter nbextension enable --py widgetsnbextension
# ! python -c "from torchvision import models; model = models.resnet50(pretrained=True)" 


In [5]:
# ! git clone https://github.com/cocodataset/cocoapi.git

In [6]:
import sys

In [7]:
from PIL import Image

In [8]:
# class cocoDataset(torch.utils.data.Dataset):
#     def __init__(self, root, annotation, transforms=None):
#         self.root = root
#         self.transforms = transforms
#         self.coco = COCO(annotation)

In [9]:
# Source: https://medium.com/fullstackai/how-to-train-an-object-detector-with-your-own-coco-dataset-in-pytorch-319e7090da5
class cocoDataset(torch.utils.data.Dataset):
    def __init__(self, root, annotation, transforms=None, target_transform=None):
        self.root = root
        self.transforms = transforms
        self.coco = COCO(annotation)
        self.ids = list(sorted(self.coco.imgs.keys()))
        self.target_transform = target_transform

    
    def __getitem__(self, index):
        # Coco File
        coco = self.coco
        # Image ID
        img_id = self.ids[index]
        # List: get annotation id for img from coco
        ann_ids = coco.getAnnIds(imgIds=img_id)
        # Dictionary: coco annotations for an image
        coco_annotation = coco.loadAnns(ann_ids)
        # path for input image
        path = coco.loadImgs(img_id)[0]['file_name']
        # open the input image with PIL
        img = Image.open(os.path.join(self.root, path)).convert("RGB")
        
        # number of objects in the image
        num_objs = len(coco_annotation)
        
        # Extracting bounding boxes for objects
        # In coco format, bbox = [xmin, ymin, xmax, ymax]
        # In pytorch, input should be [xmin, ymin, xmax, ymax]
        boxes = []
        # Labels for each objet
        labels = []
        # Area of each bounding box
        areas = []
        # IsCrowd: whether or not the object is a crowd of objects
        iscrowd = []
        xy=[]
        subimages=[]
        for i in range(num_objs):
            xmin = coco_annotation[i]['bbox'][0]
            ymin = coco_annotation[i]['bbox'][1]
            xmax = xmin + coco_annotation[i]['bbox'][2]
            ymax = ymin + coco_annotation[i]['bbox'][3]
            bbox=np.array([xmin, ymin, xmax, ymax]).astype(int)
            boxes.append(bbox)
            labels.append(coco_annotation[i]['category_id'])
            areas.append(coco_annotation[i]['area'])
            iscrowd.append(coco_annotation[i]['iscrowd'])
            xy.append([xmin+(xmax-xmin)/2,ymin+(ymax-ymin)/2])
            subimages.append(img.crop(bbox))
            
        boxes = torch.as_tensor(boxes, dtype=torch.float32)
        areas = torch.as_tensor(areas, dtype=torch.float32)
        iscrowd = torch.as_tensor(iscrowd, dtype=torch.float32)
        
        # Annotation is in dictionary format
        my_annotation = {}
        my_annotation['boxes'] = boxes
        my_annotation['labels'] = labels
        my_annotation['image_id'] = img_id
        my_annotation['area'] = areas
        my_annotation['iscrowd'] = iscrowd
        
        imgs=[img]
        
        if self.transforms is not None:
            imgs = [self.transforms(img) for img in subimages]
        if self.target_transform is not None:
            labels = [self.target_transform(label) for label in labels]
#         if self.
        labels = torch.as_tensor(labels, dtype=torch.int64)
            
        imgs=torch.stack(imgs)
        
        xy=torch.tensor(xy)
            
        return imgs, xy, torch.LongTensor([index]*len(labels)), torch.LongTensor(labels)
    
    def __len__(self):
        return len(self.ids)
        

In [10]:
### TODO: add target transform to convert labels to proper labels using id_dict

In [11]:
def load_data(annFile):
    # import json
    with open(annFile, 'r') as f:
        data = json.load(f)
    coco = COCO(annFile)
    # _, data['categories'] = fix_ids(data)
    return coco, data

def fix_ids(data):
    '''
    Takes in 'categories' key of a COCO dataset, returns new IDs for those categories (properly mapped from 0 to len(categories)-1)
    '''
    categories = data['categories']
    id_dict = {}
    for idx, cat in enumerate(categories):
        id_dict[cat['id']] = idx
        cat['id'] = idx
    return id_dict, categories

def drop_null_annotations(coco, data, dataDir, dataType, annFile, overwrite, map_ids=True, save=True, tmpDataDir='data/temp'):
    """
    Takes in a json.load(f) of an annotation file, finds all image ids without an annotation, then drops those images from the file.
    Writes to a new json file without those images.
    """
    # Writes new data to a cleaned json instances file
    fname = tmpDataDir+f"/annotations/clean_instances_{dataType}.json"
    if os.path.isfile(fname) and overwrite==False: print(fname, 'already exists')
    elif os.path.isfile(fname) and overwrite==True: print(f'{fname} already exists, overwriting anyways b/c overwrite=True')
        # if map_ids:
        #     new_data = data.copy()
        #     id_dict = {}
        #     for i, cat in enumerate(new_data['categories']):
        #         id_dict[cat['id']] = i
        #         cat['id'] = i

        # return id_dict


    images_pd = pd.Series(data['images'])
    new_images_pd = images_pd.copy()
    new_data = data.copy()
    new_data['images'] = \
        new_images_pd.loc[~images_pd.apply(lambda x: len(coco.getAnnIds(x['id']))==0)].tolist()
    # sets new_data['images'] to only the list of images with one or more annotations
#     if os.path.isdir(tmpDataF+"/annotations") == False: os.mkdir(dataDir+"/annotations")
    if map_ids: id_dict, _ = fix_ids(new_data)
    print(os.getcwd())

    # print(fname)
    if save:
        try:
            open(fname, 'w')
        except FileNotFoundError:
            os.mkdir(tmpDataDir+'/annotations')
            print('annotation directory created')
        with open(fname, 'w') as f:
            json.dump(new_data, f)
    else:
        print("file saving skipped")
    print("Start images:", len(data['images']))
    print("Images remaining:",len(new_data['images']))
    print("Number of images with no annotations:",len(data['images'])-len(new_data['images']))
    if map_ids: return id_dict
    return new_data

In [12]:


def get_transform():
    custom_transforms = []
    custom_transforms.append(torchvision.transforms.Resize(size=(128, 128)))
    custom_transforms.append(torchvision.transforms.ToTensor())
    return torchvision.transforms.Compose(custom_transforms)

In [13]:
# os.listdir('/datasets/COCO-2017/')

In [14]:
train_dir = '/datasets/COCO-2017/train2017'
val_dir = '/datasets/COCO-2017/val2017'

train_ann = '/datasets/COCO-2017/anno2017/instances_train2017.json'
val_ann = '/datasets/COCO-2017/anno2017/instances_val2017.json'

In [15]:
val_coco, val_data = load_data(val_ann)


loading annotations into memory...
Done (t=0.63s)
creating index...
index created!


In [16]:
pwd

'/home/jdlevy/COGS_185/cogs_185_final_project/notebooks'

In [17]:
id_dict = drop_null_annotations(coco=val_coco, data=val_data, dataDir='/datasets/COCO-2017', dataType='val2017',
                      annFile=val_ann, overwrite=False, tmpDataDir='../data/temp')

../data/temp/annotations/clean_instances_val2017.json already exists
/home/jdlevy/COGS_185/cogs_185_final_project/notebooks
Start images: 5000
Images remaining: 4952
Number of images with no annotations: 48


In [18]:
del val_coco, val_data

In [19]:
def convert_ids(target):
    return id_dict[target]

In [40]:
# # %%capture
# train_coco, train_data = load_data(train_ann)
# %time drop_null_annotations(coco=train_coco, data=train_data, dataDir='/datasets/COCO-2017', dataType='train2017',\
#                             annFile=train_ann, overwrite=True, tmpDataDir='../data/temp');

In [21]:
# del train_coco
# del train_data

# del val_coco
# del val_data

In [22]:
clean_train_ann = '../data/temp/annotations/clean_instances_train2017.json'
clean_val_ann = '../data/temp/annotations/clean_instances_val2017.json'

In [23]:
%time train_set = cocoDataset(root=train_dir,\
                      annotation=clean_train_ann,\
                      transforms=get_transform(),\
                      target_transform=convert_ids)

%time val_set = cocoDataset(root=val_dir,\
                      annotation=clean_val_ann,\
                      transforms=get_transform(),\
                      target_transform=convert_ids)

loading annotations into memory...
Done (t=17.53s)
creating index...
index created!
CPU times: user 16.4 s, sys: 2.38 s, total: 18.8 s
Wall time: 18.7 s
loading annotations into memory...
Done (t=0.57s)
creating index...
index created!
CPU times: user 532 ms, sys: 91.3 ms, total: 623 ms
Wall time: 621 ms


In [24]:
def collate_fn(batch):
    imgs=torch.cat([item[0] for item in batch],dim=0)
    xy=torch.cat([item[1] for item in batch],dim=0)
    idx=torch.cat([item[2] for item in batch]).flatten()
    y=torch.cat([item[3] for item in batch]).flatten()
    return [imgs,xy,idx,y]

In [25]:
# ## Params cell
num_epoch = 10
fname = f'graph_80_{num_epoch}_resnet50'
train_batch_size = 12
k = 5

In [26]:


trainloader = torch.utils.data.DataLoader(train_set,
                                          batch_size=train_batch_size,
                                          shuffle=True,
                                          num_workers=8,
                                          collate_fn=collate_fn)

valloader = torch.utils.data.DataLoader(val_set,
                                        batch_size=train_batch_size,
                                        shuffle=True,
                                        num_workers=8,
                                        collate_fn=collate_fn)



In [27]:
# del train_set, val_set

In [28]:
# Defining our model for GCN embeddings
model = models.resnet50(pretrained=False)
state_dict=torch.load("../src/models/resnet50-19c8e357.pth")
model.load_state_dict(state_dict)
model.fc = nn.Linear(2048, 80)
model.fc=nn.Flatten()

In [29]:
model=model.cuda()

In [30]:
from torch_geometric.data import Data
import tqdm
from torch_cluster import knn_graph

In [31]:
from torch_geometric.data import Data
import tqdm
from torch_cluster import knn_graph

graphs=[]
Idx=[]
Z=[]
XY=[]
Y=[]
with torch.no_grad():
    model.eval()
    for i,(imgs,xy,idx,y) in tqdm.tqdm(enumerate(trainloader),total=len(trainloader.dataset)//trainloader.batch_size):
#         print(imgs.shape,xy.shape,y.shape)
#         edge_index=knn_graph(xy,k=5)
        imgs=imgs.cuda()
        z=model(imgs).cpu()
        Z.append(z)
        XY.append(xy)
        Y.append(y)
        Idx.append(idx)
        
#         graphs.append(Data(x=z,pos=xy,edge_index=edge_index,y=y))
    #     print(graphs[i].x.shape)
#         if i==100: break
        del imgs, z, xy, y, idx

9773it [12:38, 12.89it/s]                            


In [32]:
del trainloader

In [33]:
Z=torch.cat(Z,0)
XY=torch.cat(XY,0)
Y=torch.cat(Y).flatten()
Idx=torch.cat(Idx).flatten().numpy()

In [34]:
train_graphs=[]

for idx in tqdm.tqdm(np.unique(Idx)):
    select=(Idx==idx)
    z=Z[select]
    xy=XY[select]
    y=Y[select]
    edge_index=knn_graph(xy,k=k)
    train_graphs.append(Data(x=z,pos=xy,edge_index=edge_index,y=y))

100%|██████████| 117266/117266 [16:43<00:00, 116.80it/s]


In [35]:
val_graphs=[]
Idx=[]
Z=[]
XY=[]
Y=[]
with torch.no_grad():
    model.eval()
    for i,(imgs,xy,idx,y) in tqdm.tqdm(enumerate(valloader),total=len(valloader.dataset)//valloader.batch_size):
#         print(imgs.shape,xy.shape,y.shape)
#         edge_index=knn_graph(xy,k=5)
        imgs=imgs.cuda()
        z=model(imgs).cpu() # the 2048-dimensional embeddings provided by resnet
        Z.append(z)
        XY.append(xy)
        Y.append(y)
        Idx.append(idx)
        
#         graphs.append(Data(x=z,pos=xy,edge_index=edge_index,y=y))
    #     print(graphs[i].x.shape)
#         if i==100: break
        del imgs

413it [00:34, 12.04it/s]                         


In [36]:
Z=torch.cat(Z,0)
XY=torch.cat(XY,0)
Y=torch.cat(Y).flatten()
Idx=torch.cat(Idx).flatten().numpy()

In [37]:
for idx in tqdm.tqdm(np.unique(Idx)):
    select=(Idx==idx)
    z=Z[select]
    xy=XY[select]
    y=Y[select]
    edge_index=knn_graph(xy,k=k)
    val_graphs.append(Data(x=z,pos=xy,edge_index=edge_index,y=y))

100%|██████████| 4952/4952 [00:13<00:00, 375.93it/s]


In [38]:
del valloader

In [39]:
k

5

In [42]:

mod_select = input('COCO or resnet50 graph embeddings? (coco/resnet50) ')
# if mod_select == 'coco':
    

COCO or resnet50 graph embeddings? (coco/resnet50)resnet50


In [45]:
%%time 
### For writing new train graphs
pickle_name = f'../data/temp/{mod_select}_train_graphs_k_{k}.pkl'
ans = input(f'Are you sure you want to overwrite {pickle_name}? ([y],n)')
if ans == 'y' or ans=='':
    with open(pickle_name, 'wb') as f:
    #     val_graphs = pickle.load(f)
        pickle.dump(train_graphs, f)

Are you sure you want to overwrite ../data/temp/resnet50_train_graphs_k_5.pkl? ([y],n)
CPU times: user 32.7 s, sys: 14.6 s, total: 47.3 s
Wall time: 55.8 s


In [46]:
%%time 
### For writing new val graphs
val_pickle_name = f'../data/temp/{mod_select}_val_graphs_k_{k}.pkl'

ans = input(f'Are you sure you want to overwrite {val_pickle_name}? ([y],n)')
if ans == 'y' or ans =='':
    with open(f'../data/temp/val_graphs_k_{k}.pkl', 'wb') as f:
    #     val_graphs = pickle.load(f)
        pickle.dump(val_graphs, f)

Are you sure you want to overwrite ../data/temp/resnet50_val_graphs_k_5.pkl? ([y],n)
CPU times: user 1.17 s, sys: 532 ms, total: 1.7 s
Wall time: 11.5 s


In [65]:
k

5

In [46]:
%%time 
### For reading train_graphs
with open(f'../data/temp/{mod_select}_train_graphs_k_{k}.pkl', 'rb') as f:
    train_graphs = pickle.load(f)
#     pickle.dump(val_graphs, f)

CPU times: user 43.4 s, sys: 15.8 s, total: 59.2 s
Wall time: 59.2 s


In [66]:
%%time 
### For reading val_graphs
with open(f'../data/temp/{mod_select}_val_graphs_k_{k}.pkl', 'rb') as f:
    val_graphs = pickle.load(f)
#     pickle.dump(val_graphs, f)

CPU times: user 1.83 s, sys: 412 ms, total: 2.24 s
Wall time: 2.23 s


In [47]:
from torch_geometric.data import DataLoader as TG_DataLoader
# https://pytorch.org/docs/stable/data.html#torch.utils.data.Subset
# Batch size should be the same as batch size for dataloader above
graph_trainloader = TG_DataLoader(train_graphs, batch_size=train_batch_size, shuffle=True)

graph_valloader=TG_DataLoader(val_graphs,batch_size=train_batch_size,shuffle=True)#[0]

In [48]:
model_gcn=GCNNet(2048,80, [32]*2)
model_mlp=nn.Sequential(nn.Sequential(nn.Linear(2048,32),nn.ReLU(),nn.Dropout(p=0.3)),nn.Linear(32,80))

In [49]:
# model_gcn.cuda()

In [50]:
# If there are GPUs, choose the first one for computing. Otherwise use CPU.
# device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# print(device)
# If 'cuda:0' is printed, it means GPU is available.

In [51]:
# model_gcn.to(device)

In [54]:
# Params Cell
k = 3
epochs = 10
fname = f'{mod_select}_graph_80_{epochs}_{k}'
batch_size=12
fname

'resnet50_graph_80_10_3'

In [55]:
### TODO: implement edge index as part of dataloader

In [56]:
loss_func = nn.CrossEntropyLoss()
opt = optim.SGD(model_gcn.parameters(), lr=0.0001, momentum=0.9)
avg_losses = []
print_freq = 100

In [None]:
for epoch in range(epochs):
    running_loss = 0.0
    for i, G in enumerate(graph_trainloader, 0):
        y=G.y#.cuda()
        z=G.x#.cuda()
        edge_index=G.edge_index#.cuda()
        
        opt.zero_grad()
        
        outputs = model_gcn(z, edge_index)
        loss = loss_func(outputs, y)
        
        loss.backward()
        
        opt.step()
        
        running_loss += loss.item()
        if i % print_freq == print_freq - 1:
            avg_loss = running_loss / print_freq
            print(f'[epoch: {epoch}, i: {i}] avg mini-batch loss: {avg_loss}')
            avg_losses.append(avg_loss)
            running_loss = 0.0
        
#         y_pred=model_gcn(z,edge_index)
#         print(y_pred.argmax(1))

[epoch: 0, i: 99] avg mini-batch loss: 4.295726537704468
[epoch: 0, i: 199] avg mini-batch loss: 3.933127408027649
[epoch: 0, i: 299] avg mini-batch loss: 3.8372894668579103
[epoch: 0, i: 399] avg mini-batch loss: 3.76384491443634
[epoch: 0, i: 499] avg mini-batch loss: 3.7350661659240725
[epoch: 0, i: 599] avg mini-batch loss: 3.6771505641937257
[epoch: 0, i: 699] avg mini-batch loss: 3.5944129276275634
[epoch: 0, i: 799] avg mini-batch loss: 3.6781405782699585
[epoch: 0, i: 899] avg mini-batch loss: 3.598031578063965
[epoch: 0, i: 999] avg mini-batch loss: 3.5866683602333067
[epoch: 0, i: 1099] avg mini-batch loss: 3.5426287269592285
[epoch: 0, i: 1199] avg mini-batch loss: 3.5633253765106203
[epoch: 0, i: 1299] avg mini-batch loss: 3.4951982736587524
[epoch: 0, i: 1399] avg mini-batch loss: 3.4962286138534546
[epoch: 0, i: 1499] avg mini-batch loss: 3.431639168262482
[epoch: 0, i: 1599] avg mini-batch loss: 3.4474546265602113
[epoch: 0, i: 1699] avg mini-batch loss: 3.49751808881759

[epoch: 1, i: 4299] avg mini-batch loss: 2.964483528137207
[epoch: 1, i: 4399] avg mini-batch loss: 3.0205546498298643
[epoch: 1, i: 4499] avg mini-batch loss: 2.935910210609436
[epoch: 1, i: 4599] avg mini-batch loss: 2.903560700416565
[epoch: 1, i: 4699] avg mini-batch loss: 2.940667595863342
[epoch: 1, i: 4799] avg mini-batch loss: 2.91302551984787
[epoch: 1, i: 4899] avg mini-batch loss: 2.962537581920624
[epoch: 1, i: 4999] avg mini-batch loss: 2.895660493373871
[epoch: 1, i: 5099] avg mini-batch loss: 2.976808066368103
[epoch: 1, i: 5199] avg mini-batch loss: 2.9110999596118927
[epoch: 1, i: 5299] avg mini-batch loss: 2.8825952506065367
[epoch: 1, i: 5399] avg mini-batch loss: 2.9123137307167055
[epoch: 1, i: 5499] avg mini-batch loss: 2.904320688247681
[epoch: 1, i: 5599] avg mini-batch loss: 2.8773550176620484
[epoch: 1, i: 5699] avg mini-batch loss: 2.914371967315674
[epoch: 1, i: 5799] avg mini-batch loss: 2.9443161916732787
[epoch: 1, i: 5899] avg mini-batch loss: 2.85544579

[epoch: 2, i: 8399] avg mini-batch loss: 2.6329364621639253
[epoch: 2, i: 8499] avg mini-batch loss: 2.556089026927948
[epoch: 2, i: 8599] avg mini-batch loss: 2.644288957118988
[epoch: 2, i: 8699] avg mini-batch loss: 2.621056158542633
[epoch: 2, i: 8799] avg mini-batch loss: 2.6299956798553468
[epoch: 2, i: 8899] avg mini-batch loss: 2.6203514432907102
[epoch: 2, i: 8999] avg mini-batch loss: 2.5119445073604583
[epoch: 2, i: 9099] avg mini-batch loss: 2.5613509225845337
[epoch: 2, i: 9199] avg mini-batch loss: 2.5900547921657564
[epoch: 2, i: 9299] avg mini-batch loss: 2.625772867202759
[epoch: 2, i: 9399] avg mini-batch loss: 2.5845332300662993
[epoch: 2, i: 9499] avg mini-batch loss: 2.5618958640098572
[epoch: 2, i: 9599] avg mini-batch loss: 2.57918670296669
[epoch: 2, i: 9699] avg mini-batch loss: 2.554952073097229
[epoch: 3, i: 99] avg mini-batch loss: 2.565951261520386
[epoch: 3, i: 199] avg mini-batch loss: 2.573338257074356
[epoch: 3, i: 299] avg mini-batch loss: 2.5728598964

[epoch: 4, i: 2899] avg mini-batch loss: 2.39104635477066
[epoch: 4, i: 2999] avg mini-batch loss: 2.5146977877616883
[epoch: 4, i: 3099] avg mini-batch loss: 2.4103864312171934
[epoch: 4, i: 3199] avg mini-batch loss: 2.4222772479057313
[epoch: 4, i: 3299] avg mini-batch loss: 2.403122798204422
[epoch: 4, i: 3399] avg mini-batch loss: 2.425963945388794
[epoch: 4, i: 3499] avg mini-batch loss: 2.4612422227859496
[epoch: 4, i: 3599] avg mini-batch loss: 2.442263287305832
[epoch: 4, i: 3699] avg mini-batch loss: 2.377813640832901
[epoch: 4, i: 3799] avg mini-batch loss: 2.3972443640232086
[epoch: 4, i: 3899] avg mini-batch loss: 2.464518733024597
[epoch: 4, i: 3999] avg mini-batch loss: 2.4545074677467347
[epoch: 4, i: 4099] avg mini-batch loss: 2.410927994251251
[epoch: 4, i: 4199] avg mini-batch loss: 2.400344716310501
[epoch: 4, i: 4299] avg mini-batch loss: 2.4906189358234405
[epoch: 4, i: 4399] avg mini-batch loss: 2.440187945365906
[epoch: 4, i: 4499] avg mini-batch loss: 2.4010310

[epoch: 5, i: 6999] avg mini-batch loss: 2.3644348645210265
[epoch: 5, i: 7099] avg mini-batch loss: 2.3898969721794128
[epoch: 5, i: 7199] avg mini-batch loss: 2.3269088459014893
[epoch: 5, i: 7299] avg mini-batch loss: 2.3329247033596037
[epoch: 5, i: 7399] avg mini-batch loss: 2.2856049191951753
[epoch: 5, i: 7499] avg mini-batch loss: 2.304569422006607
[epoch: 5, i: 7599] avg mini-batch loss: 2.403701719045639
[epoch: 5, i: 7699] avg mini-batch loss: 2.380055112838745
[epoch: 5, i: 7799] avg mini-batch loss: 2.3776854836940764
[epoch: 5, i: 7899] avg mini-batch loss: 2.376985387802124
[epoch: 5, i: 7999] avg mini-batch loss: 2.357366272211075
[epoch: 5, i: 8099] avg mini-batch loss: 2.341524577140808
[epoch: 5, i: 8199] avg mini-batch loss: 2.2775875508785246
[epoch: 5, i: 8299] avg mini-batch loss: 2.3706978344917298
[epoch: 5, i: 8399] avg mini-batch loss: 2.3478330636024474
[epoch: 5, i: 8499] avg mini-batch loss: 2.3455031371116637
[epoch: 5, i: 8599] avg mini-batch loss: 2.306

ERROR:root:Internal Python error in the inspect module.
Below is the traceback from this internal error.

ERROR:root:Internal Python error in the inspect module.
Below is the traceback from this internal error.



Traceback (most recent call last):
  File "/datasets/home/41/641/jdlevy/.conda/envs/torchgeo/lib/python3.7/site-packages/IPython/core/interactiveshell.py", line 3418, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-57-f5e2f5ca8a5c>", line 3, in <module>
    for i, G in enumerate(graph_trainloader, 0):
  File "/datasets/home/41/641/jdlevy/.conda/envs/torchgeo/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 363, in __next__
    data = self._next_data()
  File "/datasets/home/41/641/jdlevy/.conda/envs/torchgeo/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 403, in _next_data
    data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
  File "/datasets/home/41/641/jdlevy/.conda/envs/torchgeo/lib/python3.7/site-packages/torch/utils/data/_utils/fetch.py", line 47, in fetch
    return self.collate_fn(data)
  File "/datasets/home/41/641/jdlevy/.conda/envs/torchgeo/lib/python3.7/site-packages/torch_geomet

ERROR:root:Internal Python error in the inspect module.
Below is the traceback from this internal error.



Traceback (most recent call last):
  File "/datasets/home/41/641/jdlevy/.conda/envs/torchgeo/lib/python3.7/site-packages/IPython/core/interactiveshell.py", line 3418, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-57-f5e2f5ca8a5c>", line 3, in <module>
    for i, G in enumerate(graph_trainloader, 0):
  File "/datasets/home/41/641/jdlevy/.conda/envs/torchgeo/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 363, in __next__
    data = self._next_data()
  File "/datasets/home/41/641/jdlevy/.conda/envs/torchgeo/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 403, in _next_data
    data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
  File "/datasets/home/41/641/jdlevy/.conda/envs/torchgeo/lib/python3.7/site-packages/torch/utils/data/_utils/fetch.py", line 47, in fetch
    return self.collate_fn(data)
  File "/datasets/home/41/641/jdlevy/.conda/envs/torchgeo/lib/python3.7/site-packages/torch_geomet

In [60]:
PATH = f'../src/models/{fname}.pth'
torch.save(model_gcn.state_dict(), PATH)#, _use_new_zipfile_serialization=False)
print(f'Model saved at {PATH}')

Model saved at ../src/models/resnet50_graph_80_10_3.pth


In [61]:
len(avg_losses)

575

In [62]:
pd.Series(avg_losses).to_csv(f'../data/out/avg_losses_{fname}.csv')

In [74]:
# model.state_dict()

In [44]:
# for epoch in epochs:
#     for G in graph_dataloader:
#         y=G.y.cuda()
#         z=G.x.cuda()
#         edge_index=G.edge_index.cuda()
#         y_pred=model_gcn(z,edge_index)
#         print(y_pred.argmax(1))