In [1]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
root_dir = '/content/drive/MyDrive/Models_exp6'

In [3]:
import os
import multiprocessing
import numpy as np
import cv2
import time
import shutil

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torchvision import transforms

In [4]:
if torch.cuda.is_available():  
  dev = "cuda:0" 
else:  
  dev = "cpu"  

device = torch.device(dev)
print(device)

cuda:0


In [5]:
!nvidia-smi

Mon Apr 26 17:28:09 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 465.19.01    Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla V100-SXM2...  Off  | 00000000:00:04.0 Off |                    0 |
| N/A   36C    P0    25W / 300W |      2MiB / 16160MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [6]:
#https://towardsdatascience.com/implementing-the-new-state-of-the-art-mish-activation-with-2-lines-of-code-in-pytorch-e7ef438a5ee7
def mish(x): 
  return (x * torch.tanh(nn.functional.softplus(x)))

In [7]:
#https://machinelearningmastery.com/pytorch-tutorial-develop-deep-learning-models/
class Model_3d(nn.Module):
    '''
    A model to use on 3d objects
    '''
    def __init__(self, fc_nodes, no_views, no_classes):
      super(Model_3d, self).__init__()
      self.resnet = torchvision.models.resnet50(pretrained=True) #https://pytorch.org/tutorials/beginner/transfer_learning_tutorial.html

      num_features = self.resnet.fc.out_features

      self.fcs = nn.ModuleList()

      for idx in range(no_views):
        self.fcs.append(nn.Linear(num_features, fc_nodes))

      first_dense_count = fc_nodes * no_views

      self.dense_1 = nn.Linear(first_dense_count, 1028)
      self.dense_2 = nn.Linear(1028, no_classes)

    
    def forward(self, x):
      features = []

      for i, fc in enumerate(self.fcs):
        sample = x[:, i, ...]
        r_out = self.resnet(sample)

        fc_out = fc(r_out)
        features.append(fc_out)
      
      stack = torch.stack(features, 1)

      reshaped = stack.view([x.shape[0], -1])

      mish_out = mish(reshaped)

      x = self.dense_1(mish_out)
      x = self.dense_2(mish(x))


      return torch.nn.functional.softmax(x)

In [8]:
class Date_3d_Cached(torch.utils.data.Dataset):
    def __init__(self, x, y):
      self.x = torch.from_numpy(x)
      self.y = torch.from_numpy(y)
    
    def __len__(self):
      return self.x.shape[0]

    def __getitem__(self, idx):
      return self.x[idx] /255.0, self.y[idx]

In [9]:
views = ['bottom', 'side_1', 'side_2', 'side_3', 'side_4', 'top']
x = np.load('/content/drive/MyDrive/Cache/x.npy')
y = np.load('/content/drive/MyDrive/Cache/y.npy')

training_data = Date_3d_Cached(x, y)

In [10]:
#90/10 train val split
train_len = int(len(training_data) * .9)
val_len = len(training_data) - train_len

In [11]:
train_set, val_set = torch.utils.data.random_split(training_data, [train_len, val_len],torch.Generator().manual_seed(42))

In [12]:
batch_size = 16

In [13]:
model = Model_3d(1024, len(views), 10)
model.to(device)
print(device)
training_loader = torch.utils.data.DataLoader(train_set, batch_size= batch_size, shuffle=True, num_workers=os.cpu_count())
val_loader = torch.utils.data.DataLoader(val_set, batch_size= batch_size, shuffle=True, num_workers=os.cpu_count())

Downloading: "https://download.pytorch.org/models/resnet50-19c8e357.pth" to /root/.cache/torch/hub/checkpoints/resnet50-19c8e357.pth


HBox(children=(FloatProgress(value=0.0, max=102502400.0), HTML(value='')))


cuda:0


In [14]:
print(os.cpu_count())
print(train_len)
print(val_len)

2
3591
400


In [15]:
#https://medium.com/analytics-vidhya/saving-and-loading-your-model-to-resume-training-in-pytorch-cb687352fa61
def save_ckp(state, model_path):
    torch.save(state, model_path)

In [20]:
#https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

if not os.path.exists(root_dir):
    os.makedirs(root_dir)


low_val_loss = 30000000
best_acc = -1
epoch = 0

val_loss_path = None
acc_path = None
#https://stackoverflow.com/questions/8078330/csv-writing-within-loop
import csv
with open(os.path.join(root_dir, 'logs.csv'), 'w') as file_csv:
  writer=csv.writer(file_csv, delimiter=',',lineterminator='\n',)
  writer.writerow(['epoch', 'loss', 'acc', 'val_loss', 'val_acc'])


  for epoch in range(1, 76):
      row = [epoch]
      is_best = False
      running_loss = 0.0
      correct = 0
      total = 0

      model.train()

      t0 = time.time()
      for i, data in enumerate(training_loader):
          # get the inputs; data is a list of [inputs, labels]
          inputs, labels = data

          for this_view in range(inputs.shape[1]):
            if np.random.randint(0, 4) == 0:
              inputs[:, this_view, ...] = 1.0

          inputs = inputs.to(device)
          labels = labels.to(device)

          # zero the parameter gradients
          optimizer.zero_grad()

          # forward + backward + optimize
          outputs = model(inputs)
          loss = criterion(outputs, labels)
          loss.backward()
          optimizer.step()

          # print statistics
          correct += (outputs.argmax(1) == labels).float().sum() #https://stackoverflow.com/questions/51503851/calculate-the-accuracy-every-epoch-in-pytorch
          total += labels.shape[0]
          running_loss += loss.item()

      print('epoch', epoch)
      
      print('{} seconds'.format(time.time() - t0))
      print('loss', running_loss)
      row.append(running_loss)
      running_loss = 0.0

      accuracy = 100 * correct / total
      print("Accuracy = {}".format(accuracy))
      row.append(torch.IntTensor.item(accuracy))
      correct = 0
      total= 0

      running_loss = 0.0
      correct = 0
      total = 0

      optimizer.zero_grad()
      model.eval()
      t0 = time.time()
      with torch.no_grad():
          for i, data in enumerate(val_loader):
              
              # get the inputs; data is a list of [inputs, labels]
              inputs, labels = data
              inputs = inputs.to(device)
              labels = labels.to(device)            


              # forward + backward + optimize
              outputs = model(inputs)
              loss = criterion(outputs, labels)

              correct += (outputs.argmax(1) == labels).float().sum() #https://stackoverflow.com/questions/51503851/calculate-the-accuracy-every-epoch-in-pytorch
              total += labels.shape[0]
              running_loss += loss.item()

      print('{} seconds'.format(time.time() - t0))
      print('val loss', running_loss)
      row.append(running_loss)
      

      accuracy = 100 * correct / total
      print(" Val Accuracy = {}".format(accuracy))
      row.append(torch.IntTensor.item(accuracy))

      model_name = 'ep_' + str(epoch) + '_loss_' + str(running_loss) + '_acc_' + str(accuracy) + '.pt'
      full_model_path = os.path.join(root_dir, model_name)
      checkpoint = {
          'epoch': epoch + 1,
          'state_dict': model.state_dict(),
          'optimizer': optimizer.state_dict()
      }
      if running_loss < low_val_loss:
        if val_loss_path and os.path.exists(val_loss_path):
          open(val_loss_path, 'w').close() #overwrite and make the file blank instead - ref: https://stackoverflow.com/a/4914288/3553367
          os.remove(val_loss_path)
        val_loss_path = full_model_path.replace('.pt', 'val_loss_best.pt')
        save_ckp(model, val_loss_path)
        low_val_loss = running_loss
        print('new low loss')

      if accuracy > best_acc:
        if acc_path and os.path.exists(acc_path):
          open(acc_path, 'w').close() #overwrite and make the file blank instead - ref: https://stackoverflow.com/a/4914288/3553367
          os.remove(acc_path)
        acc_path = full_model_path.replace('.pt', 'acc_best.pt')
        save_ckp(model, acc_path)
        best_acc = accuracy
        print('new best acc')
      full_model_path = os.path.join(root_dir, 'last.pt')
      save_ckp(checkpoint, full_model_path)
          

      

      correct = 0
      total= 0
      running_loss = 0.0
      writer.writerow(row)
  print('Finished Training')



epoch 1
80.97265481948853 seconds
loss 477.9483691453934
Accuracy = 36.2573127746582
3.3958404064178467 seconds
val loss 46.96899461746216
 Val Accuracy = 61.0
new low loss
new best acc
epoch 2
81.4390320777893 seconds
loss 416.12555158138275
Accuracy = 62.12754440307617
3.3754637241363525 seconds
val loss 42.77186191082001
 Val Accuracy = 77.5
new low loss
new best acc
epoch 3
81.40606117248535 seconds
loss 384.44534635543823
Accuracy = 78.27903747558594
3.382248878479004 seconds
val loss 41.63805150985718
 Val Accuracy = 80.25
new low loss
new best acc
epoch 4
80.98014187812805 seconds
loss 370.8237359523773
Accuracy = 83.98775482177734
3.3700966835021973 seconds
val loss 42.7879102230072
 Val Accuracy = 77.75
epoch 5
80.63347816467285 seconds
loss 364.8624566793442
Accuracy = 84.87886810302734
3.3684260845184326 seconds
val loss 40.38382017612457
 Val Accuracy = 84.75
new low loss
new best acc
epoch 6
80.81470489501953 seconds
loss 362.37051951885223
Accuracy = 85.51935577392578
3.3

In [None]:
import pandas as pd
     
# dictionary of lists  
this_dict = {'training loss': training_loss_list, 'training acc': training_acc_list, 'val loss': val_loss_list, 'val acc':val_acc_list}  
       
df = pd.DataFrame(this_dict) 
    
# saving the dataframe 
df.to_csv(os.path.join(root_dir, 'logs.csv'))

In [None]:
print('saved csv')

In [None]:
print(row)