## Using a multi-core TPU to accelerate neural network training





**Step 1.**
Install the PyTorch/XLA libraries needed to communicate with the TPU.





In [None]:
import os

# Fail fast with a readable message when the runtime has no TPU attached.
# os.environ.get() is used instead of os.environ[...] so that a missing
# variable triggers the assertion (and its hint) rather than a bare KeyError.
assert os.environ.get('COLAB_TPU_ADDR'), 'Make sure to select TPU from Edit > Notebook settings > Hardware accelerator'

In [None]:
!pip install cloud-tpu-client==0.10 https://storage.googleapis.com/tpu-pytorch/wheels/torch_xla-1.8-cp37-cp37m-linux_x86_64.whl

Collecting cloud-tpu-client==0.10
  Downloading https://files.pythonhosted.org/packages/56/9f/7b1958c2886db06feb5de5b2c191096f9e619914b6c31fdf93999fdbbd8b/cloud_tpu_client-0.10-py3-none-any.whl
Collecting torch-xla==1.7
[?25l  Downloading https://storage.googleapis.com/tpu-pytorch/wheels/torch_xla-1.7-cp37-cp37m-linux_x86_64.whl (133.6MB)
[K     |████████████████████████████████| 133.6MB 41kB/s 
Collecting google-api-python-client==1.8.0
[?25l  Downloading https://files.pythonhosted.org/packages/9a/b4/a955f393b838bc47cbb6ae4643b9d0f90333d3b4db4dc1e819f36aad18cc/google_api_python_client-1.8.0-py3-none-any.whl (57kB)
[K     |████████████████████████████████| 61kB 3.0MB/s 
Installing collected packages: google-api-python-client, cloud-tpu-client, torch-xla
  Found existing installation: google-api-python-client 1.7.12
    Uninstalling google-api-python-client-1.7.12:
      Successfully uninstalled google-api-python-client-1.7.12
Successfully installed cloud-tpu-client-0.10 google-api-

In [None]:
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader,random_split
from torch.utils.data.distributed import DistributedSampler
import matplotlib.pyplot as plt
%matplotlib inline
from sklearn.decomposition import PCA

# import torch xla APIs

import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl
import torch_xla.distributed.xla_multiprocessing as xmp



In [None]:
#train_size=int(0.4*(unprocessed_data.shape[0]))
# val_size=len(unprocessed_data) - train_size
# print(train_size,val_size)
# train_ds,_ = random_split(unprocessed_data,[train_size,val_size])

**Step 2.** Make the data mean-free and whiten it

In [None]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
train_dataset =  torch.load('/content/drive/MyDrive/Autoencoding/raw_cv_whiten.pt')


In [None]:
# Define Parameters
# Hyper-parameters and runtime settings shared by every TPU process; the dict is
# passed to xmp.spawn and read inside fit_one_cycle.
FLAGS = {}
#FLAGS['data_dir'] = "/tmp/cifar"
# Batch size of the DataLoader built in each process.
# NOTE(review): 1028 may be a typo for 1024 — confirm.
FLAGS['batch_size'] = 1028
# Number of DataLoader worker processes per training process.
FLAGS['num_workers'] = 4
# Used both as the optimizer learning rate and as the OneCycleLR peak rate.
FLAGS['max_learning_rate'] = 0.001
# Bound passed to nn.utils.clip_grad_value_ (element-wise gradient clipping).
FLAGS['grad_clip']  = 0.1
# weight_decay argument handed to the optimizer.
FLAGS['weight_decay'] = 1e-4
# Optimizer class; instantiated in fit_one_cycle.
FLAGS['opt_func']  = torch.optim.Adam
#FLAGS['momentum'] = 0.9
# Number of passes over the data; also sizes the OneCycleLR schedule.
FLAGS['num_epochs'] = 21
# Number of processes started by xmp.spawn (nprocs).
FLAGS['num_cores'] = 8 
# NOTE(review): 'log_steps' and 'metrics_debug' are not read anywhere in the
# visible notebook code — confirm whether they are still needed.
FLAGS['log_steps'] = 20
FLAGS['metrics_debug'] = False

**Wrap the dataloader for parallelization**

In [None]:
#SERIAL_EXEC = xmp.MpSerialExecutor()

In [None]:
#train_dl = DataLoader(train_ds,batch_size,shuffle=True)

In [None]:
def make_4_dim(data):
    """Insert singleton axes at positions 1 and 3.

    A (N, F) tensor becomes (N, 1, F, 1), i.e. the 4-D layout the 2-D
    convolution layers expect (one channel, F rows, one column).
    """
    return data.unsqueeze(1).unsqueeze(3)

def make_2_dim(data):
    """Undo make_4_dim: drop axis 1, then axis 2 of what remains.

    A (N, 1, F, 1) tensor becomes (N, F). Axes that are not of size 1 are
    left in place, because squeeze(dim) is a no-op on them.
    """
    without_channel = data.squeeze(1)
    return without_channel.squeeze(2)

In [None]:
def conv_block1(in_channels,out_channels,kernel_size,stride,padding):
    """Build an upsampling block: ConvTranspose2d -> BatchNorm2d -> LeakyReLU(0.2).

    The arguments are forwarded unchanged to nn.ConvTranspose2d; the result is
    an nn.Sequential holding the three layers in that order.
    """
    upsample = nn.ConvTranspose2d(in_channels, out_channels, kernel_size,
                                  stride=stride, padding=padding)
    normalize = nn.BatchNorm2d(out_channels)
    activate = nn.LeakyReLU(0.2, inplace=True)
    return nn.Sequential(upsample, normalize, activate)

def conv_block2(in_channels,out_channels,kernel_size,stride,padding,pool=False):
    """Build a downsampling block: Conv2d -> BatchNorm2d -> LeakyReLU(0.2).

    The convolution arguments are forwarded unchanged to nn.Conv2d. When
    ``pool`` is true a MaxPool2d(2) is appended as a fourth layer.
    """
    stages = (
        nn.Conv2d(in_channels, out_channels, kernel_size,
                  stride=stride, padding=padding),
        nn.BatchNorm2d(out_channels),
        nn.LeakyReLU(0.2, inplace=True),
    )
    if pool:
        stages = stages + (nn.MaxPool2d(2),)
    return nn.Sequential(*stages)
    


In [None]:
class Resnet9(nn.Module):
    """Convolutional autoencoder with residual stages.

    Both ``encode`` and ``decode`` follow the same pattern: three transposed
    convolutions (conv_block1) that blow the input up, a residual stage, two
    strided convolutions (conv_block2) that shrink it again, a second residual
    stage and a final strided convolution.

    The shape comments below are for a ``(N, C, 2, 1)`` input, which is what
    ``make_4_dim`` produces from two-column rows; they are written as
    ``channels*height*width``.

    Args:
        in_channel: number of channels of the tensor passed to ``encode``.
        out_channel: number of channels of the tensor returned by ``decode``.
    """

    def __init__(self, in_channel, out_channel):
        super().__init__()

        # The encoder always emits a single-channel latent (see conv6), so the
        # decoder's first layer must accept exactly that many channels rather
        # than `in_channel`; the two only coincide when in_channel == 1.
        latent_channel = 1

        # Encode 1: upsample
        self.conv1 = conv_block1(in_channel, 64, kernel_size=4, stride=2, padding=0)    # 64*6*4
        self.conv2 = conv_block1(64, 128, kernel_size=4, stride=2, padding=0)           # 128*14*10
        self.conv3 = conv_block1(128, 256, kernel_size=4, stride=2, padding=0)          # 256*30*22
        self.res1 = nn.Sequential(
            conv_block1(256, 256, kernel_size=3, stride=1, padding=1),                  # 256*30*22
            conv_block1(256, 256, kernel_size=3, stride=1, padding=1))                  # 256*30*22

        # Encode 2: downsample to the latent
        self.conv4 = conv_block2(256, 128, kernel_size=4, stride=2, padding=0)          # 128*14*10
        self.conv5 = conv_block2(128, 64, kernel_size=4, stride=2, padding=0)           # 64*6*4
        self.res2 = nn.Sequential(
            conv_block2(64, 64, kernel_size=3, stride=1, padding=1),                    # 64*6*4
            conv_block2(64, 64, kernel_size=3, stride=1, padding=1))                    # 64*6*4
        self.conv6 = conv_block2(64, latent_channel, kernel_size=4, stride=2, padding=0)  # 1*2*1

        # Decode 1: upsample the latent
        self.conv7 = conv_block1(latent_channel, 64, kernel_size=4, stride=2, padding=0)  # 64*6*4
        self.conv8 = conv_block1(64, 128, kernel_size=4, stride=2, padding=0)           # 128*14*10
        self.conv9 = conv_block1(128, 256, kernel_size=4, stride=2, padding=0)          # 256*30*22
        self.res3 = nn.Sequential(
            conv_block1(256, 256, kernel_size=3, stride=1, padding=1),                  # 256*30*22
            conv_block1(256, 256, kernel_size=3, stride=1, padding=1))                  # 256*30*22

        # Decode 2: downsample to the output
        self.conv10 = conv_block2(256, 128, kernel_size=4, stride=2, padding=0)         # 128*14*10
        self.conv11 = conv_block2(128, 64, kernel_size=4, stride=2, padding=0)          # 64*6*4
        self.res4 = nn.Sequential(
            conv_block2(64, 64, kernel_size=3, stride=1, padding=1),                    # 64*6*4
            conv_block2(64, 64, kernel_size=3, stride=1, padding=1))                    # 64*6*4
        # Plain convolution: no batch norm / activation on the final output.
        self.conv12 = nn.Conv2d(64, out_channel, kernel_size=4, stride=2, padding=0)    # out_channel*2*1

    def encode(self, in_data):
        """Map the input tensor to the single-channel latent tensor."""
        out = self.conv1(in_data.float())
        out = self.conv2(out)
        out = self.conv3(out)
        out = self.res1(out) + out  # residual connection
        out = self.conv4(out)
        out = self.conv5(out)
        out = self.res2(out) + out  # residual connection
        out = self.conv6(out)
        return out

    def decode(self, lat_data):
        """Map a latent tensor (as returned by ``encode``) to the output tensor."""
        out = self.conv7(lat_data.float())
        out = self.conv8(out)
        out = self.conv9(out)
        out = self.res3(out) + out  # residual connection
        out = self.conv10(out)
        out = self.conv11(out)
        out = self.res4(out) + out  # residual connection
        out = self.conv12(out)
        return out

    def forward(self, in_data):
        """Run the full autoencoder: ``decode(encode(in_data))``.

        Provided so the module can be called directly (``model(x)``); without
        it nn.Module raises NotImplementedError on call.
        """
        return self.decode(self.encode(in_data))
      
        

In [None]:
# Build the single-channel-in / single-channel-out model once at notebook level;
# each spawned process later moves it to its own XLA device with
# WRAPPED_MODEL.to(device) inside fit_one_cycle.
WRAPPED_MODEL = xmp.MpModelWrapper(Resnet9(1,1))

In [None]:
# # def get_default_device():
# #     """Pick GPU if available, else CPU"""
# #     if torch.cuda.is_available():
# #         return torch.device('cuda')
# #     else:
# #         return torch.device('cpu')
    
# def to_device(data, device):
#     """Move tensor(s) to chosen device"""
#     if isinstance(data, (list,tuple)):
#         return [to_device(x, device) for x in data]
#     return data.to(device, non_blocking=True)

# class DeviceDataLoader():
#     """Wrap a dataloader to move data to a device"""
#     def __init__(self, dl, device):
#         self.dl = dl
#         self.device = device
        
#     def __iter__(self):
#         """Yield a batch of data after moving it to device"""
#         for b in self.dl: 
#             yield to_device(b, self.device)

#     def __len__(self):
#         """Number of batches"""
#         return len(self.dl)

In [None]:
def training_step(data,encoder):
    """Compute the time-lagged reconstruction loss for one batch.

    Columns 0:2 of ``data`` are the network input and columns 2:4 the target.
    The input is reshaped to 4-D, passed through ``encoder.encode`` and
    ``encoder.decode``, flattened back to 2-D and compared with the target
    using mean-squared error.

    Returns:
        Scalar loss tensor (still attached to the autograd graph).
    """
    inputs, targets = data[:, 0:2], data[:, 2:4]
    latent = encoder.encode(make_4_dim(inputs))
    prediction = make_2_dim(encoder.decode(latent))
    return nn.functional.mse_loss(prediction.float(), targets.float())

In [None]:
# def training_step(data,encoder):
#     training_dat = data
#     training_dat = make_4_dim(training_dat)
#     out = encoder.encode(training_dat)
#     out = encoder.decode(out)
#     out = make_2_dim(out)
#     criterion = nn.MSELoss()
#     loss = criterion(out.float(),data.float())
#     return loss

# def evaluate(model,val_loader):
#     for data in val_loader:
#         val_dat = data
#         val_dat = make_4_dim(val_dat)
#         out = model.encode(val_dat)
#         out = model.decode(out)
#         out = make_2_dim(out)
#         criterion = nn.MSELoss()
#         loss = criterion(out.float(),data.float())
#         return {'val_loss':loss}

In [None]:
from tqdm.notebook import tqdm

In [None]:
# Build time-lagged training pairs: row i of `train_data` is row i of the
# dataset and row i of `shifted_train_data` is row i + lag, so both slices have
# the same length (the dataset length minus `lag`).
lag = 8
train_data = train_dataset[:-lag]
shifted_train_data = train_dataset[lag:]
# Put each sample and its lagged partner side by side in one row; training_step
# reads columns 0:2 as the input and columns 2:4 as the target.
# NOTE(review): this presumes train_dataset is a 2-D tensor with two columns — confirm.
transformed_train_data = torch.hstack((train_data,shifted_train_data))


In [None]:
# def prepare_data(unprocessed_data):

#       train_ds = unprocessed_data[:]
#       train_ds = train_ds -torch.mean(train_ds,0) # Mean free

#       pca_whiten = PCA(whiten=True)                     # Whiten data
#       train_ds = pca_whiten.fit_transform(train_ds)

#       train_ds = torch.tensor(train_ds)                  # as tensor
#       return train_ds

In [None]:
#@torch.no_grad()
def fit_one_cycle(FLAGS):
    """Train the shared autoencoder on this process's TPU core.

    Intended to run once per process started by ``xmp.spawn``. Each process
    reads its own shard of ``transformed_train_data`` through a
    ``DistributedSampler``, trains with a one-cycle learning-rate schedule and
    element-wise gradient clipping, and after every epoch saves the model
    weights and the loss history via ``xm.save``.

    Args:
        FLAGS: dict providing 'batch_size', 'num_workers', 'max_learning_rate',
            'grad_clip', 'weight_decay', 'opt_func' and 'num_epochs'.

    Returns:
        1-D tensor with this process's mean training loss for each epoch
        (empty when 'num_epochs' is 0).
    """
    torch.manual_seed(1234)
    history = []

    # Shard the data so every process sees a distinct part of it.
    train_sampler = DistributedSampler(transformed_train_data,
                                       num_replicas=xm.xrt_world_size(),
                                       rank=xm.get_ordinal(),
                                       shuffle=True)
    train_loader = DataLoader(transformed_train_data,
                              batch_size=FLAGS['batch_size'],
                              sampler=train_sampler,
                              num_workers=FLAGS['num_workers'])

    # NOTE(review): the original code computed a world-size-scaled learning
    # rate but never used it. The unscaled value is kept here so the effective
    # hyper-parameters do not change; decide separately whether scaling is wanted.
    learning_rate = FLAGS['max_learning_rate']

    device = xm.xla_device()
    encoder = WRAPPED_MODEL.to(device)

    # Optimizer with weight decay.
    optimizer = FLAGS['opt_func'](encoder.parameters(),
                                  learning_rate,
                                  weight_decay=FLAGS['weight_decay'])

    # One-cycle schedule, stepped once per batch.
    sched = torch.optim.lr_scheduler.OneCycleLR(optimizer,
                                                learning_rate,
                                                epochs=FLAGS['num_epochs'],
                                                steps_per_epoch=len(train_loader))

    for epoch in range(FLAGS['num_epochs']):
        # Without set_epoch the sampler reuses the same shuffled order in every
        # epoch. It must be called before the loader starts iterating, i.e.
        # before the ParallelLoader is created.
        train_sampler.set_epoch(epoch)

        encoder.train()
        train_losses = []
        para_loader = pl.ParallelLoader(train_loader, [device]).per_device_loader(device)

        # Only the master process draws a progress bar; otherwise every
        # process renders its own copy.
        for batch in tqdm(para_loader, disable=not xm.is_master_ordinal()):
            loss = training_step(batch, encoder)
            # Keep only the value for logging so the autograd graph of each
            # batch is not kept alive until the end of the epoch.
            train_losses.append(loss.detach())
            loss.backward()

            # Clip each gradient element to [-grad_clip, grad_clip].
            nn.utils.clip_grad_value_(encoder.parameters(),
                                      FLAGS['grad_clip'])
            xm.optimizer_step(optimizer)
            optimizer.zero_grad()

            sched.step()

        train_loss = torch.stack(train_losses).mean().item()
        print('train_loss{:.4f}'.format(train_loss))
        # Reported after the batches have run, so the message is accurate.
        xm.master_print("Finished training epoch {}".format(epoch))
        history.append(train_loss)

        # Checkpoint weights and loss history after every epoch.
        xm.save(encoder.state_dict(), 'encoder_state_dict_resnet.pth')
        xm.save(torch.tensor(history), 'history_resnet.pth')

    # Built from the list so it is defined even when no epoch ran.
    return torch.tensor(history)

## Make the Map Function 

In [None]:
def map_fn(rank,flags):
    """Per-process entry point handed to ``xmp.spawn``.

    Publishes ``flags`` as the module-level ``FLAGS``, makes float32 CPU
    tensors the default tensor type and runs the training loop. ``rank`` is
    supplied by the spawner and is not used here.
    """
    global FLAGS
    FLAGS = flags
    torch.set_default_tensor_type('torch.FloatTensor')
    fit_one_cycle(FLAGS)

if __name__ == '__main__':
    # Start one training process per TPU core using the 'fork' start method;
    # every process runs map_fn with the shared FLAGS dict.
    xmp.spawn(
        map_fn,
        args=(FLAGS,),
        nprocs=FLAGS['num_cores'],
        start_method='fork',
    )

Finished training epoch 0


HBox(children=(FloatProgress(value=0.0, max=490.0), HTML(value='')))









train_loss0.1432
train_loss0.1442
train_loss0.1443
train_loss0.1432
train_loss0.1427
train_loss0.1443
train_loss0.1438
train_loss0.1438
Finished training epoch 1


HBox(children=(FloatProgress(value=0.0, max=490.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=490.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=490.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=490.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=490.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=490.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=490.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=490.0), HTML(value='')))









train_loss0.1107
train_loss0.1105
train_loss0.1105
train_loss0.1099
train_loss0.1103
train_loss0.1097
train_loss0.1104
train_loss0.1101
Finished training epoch 2


HBox(children=(FloatProgress(value=0.0, max=490.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=490.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=490.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=490.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=490.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=490.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=490.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=490.0), HTML(value='')))









train_loss0.1115
train_loss0.1112
train_loss0.1105
train_loss0.1110
train_loss0.1107
train_loss0.1114
train_loss0.1113
train_loss0.1113
Finished training epoch 3


HBox(children=(FloatProgress(value=0.0, max=490.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=490.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=490.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=490.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=490.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=490.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=490.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=490.0), HTML(value='')))









train_loss0.1116
train_loss0.1106
train_loss0.1114
train_loss0.1109
train_loss0.1114
train_loss0.1112
train_loss0.1113
train_loss0.1116
Finished training epoch 4


HBox(children=(FloatProgress(value=0.0, max=490.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=490.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=490.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=490.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=490.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=490.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=490.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=490.0), HTML(value='')))









train_loss0.1116
train_loss0.1110
train_loss0.1113
train_loss0.1114
train_loss0.1117
train_loss0.1117
train_loss0.1116
train_loss0.1108
Finished training epoch 5


HBox(children=(FloatProgress(value=0.0, max=490.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=490.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=490.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=490.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=490.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=490.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=490.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=490.0), HTML(value='')))









train_loss0.1112
train_loss0.1111
train_loss0.1106
train_loss0.1109
train_loss0.1114
train_loss0.1104
train_loss0.1112
train_loss0.1113
Finished training epoch 6


HBox(children=(FloatProgress(value=0.0, max=490.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=490.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=490.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=490.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=490.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=490.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=490.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=490.0), HTML(value='')))









train_loss0.1103
train_loss0.1106
train_loss0.1107
train_loss0.1100
train_loss0.1106
train_loss0.1105
train_loss0.1108
train_loss0.1098
Finished training epoch 7


HBox(children=(FloatProgress(value=0.0, max=490.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=490.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=490.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=490.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=490.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=490.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=490.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=490.0), HTML(value='')))









train_loss0.1102
train_loss0.1099
train_loss0.1104
train_loss0.1094
train_loss0.1102
train_loss0.1096
train_loss0.1103
train_loss0.1101
Finished training epoch 8


HBox(children=(FloatProgress(value=0.0, max=490.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=490.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=490.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=490.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=490.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=490.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=490.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=490.0), HTML(value='')))









train_loss0.1097
train_loss0.1100
train_loss0.1101
train_loss0.1099
train_loss0.1092
train_loss0.1098
train_loss0.1094
train_loss0.1101
Finished training epoch 9


HBox(children=(FloatProgress(value=0.0, max=490.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=490.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=490.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=490.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=490.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=490.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=490.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=490.0), HTML(value='')))









train_loss0.1097
train_loss0.1098
train_loss0.1092
train_loss0.1100
train_loss0.1090
train_loss0.1099
train_loss0.1097
train_loss0.1095
Finished training epoch 10


HBox(children=(FloatProgress(value=0.0, max=490.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=490.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=490.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=490.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=490.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=490.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=490.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=490.0), HTML(value='')))









train_loss0.1090
train_loss0.1095
train_loss0.1088
train_loss0.1093
train_loss0.1096
train_loss0.1097
train_loss0.1098
train_loss0.1096
Finished training epoch 11


HBox(children=(FloatProgress(value=0.0, max=490.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=490.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=490.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=490.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=490.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=490.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=490.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=490.0), HTML(value='')))









train_loss0.1098
train_loss0.1093
train_loss0.1087
train_loss0.1090
train_loss0.1095
train_loss0.1096
train_loss0.1095
train_loss0.1094
Finished training epoch 12


HBox(children=(FloatProgress(value=0.0, max=490.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=490.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=490.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=490.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=490.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=490.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=490.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=490.0), HTML(value='')))









train_loss0.1091
train_loss0.1097
train_loss0.1086
train_loss0.1094
train_loss0.1088
train_loss0.1093
train_loss0.1095
train_loss0.1093
Finished training epoch 13


HBox(children=(FloatProgress(value=0.0, max=490.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=490.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=490.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=490.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=490.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=490.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=490.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=490.0), HTML(value='')))









train_loss0.1085
train_loss0.1096
train_loss0.1092
train_loss0.1090
train_loss0.1092
train_loss0.1094
train_loss0.1093
train_loss0.1087
Finished training epoch 14


HBox(children=(FloatProgress(value=0.0, max=490.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=490.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=490.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=490.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=490.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=490.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=490.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=490.0), HTML(value='')))









train_loss0.1094
train_loss0.1090
train_loss0.1092
train_loss0.1088
train_loss0.1086
train_loss0.1093
train_loss0.1092
train_loss0.1084
Finished training epoch 15


HBox(children=(FloatProgress(value=0.0, max=490.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=490.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=490.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=490.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=490.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=490.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=490.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=490.0), HTML(value='')))









train_loss0.1089
train_loss0.1091
train_loss0.1093
train_loss0.1088
train_loss0.1090
train_loss0.1083
train_loss0.1085
train_loss0.1092
Finished training epoch 16


HBox(children=(FloatProgress(value=0.0, max=490.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=490.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=490.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=490.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=490.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=490.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=490.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=490.0), HTML(value='')))









train_loss0.1086
train_loss0.1088
train_loss0.1090
train_loss0.1081
train_loss0.1087
train_loss0.1083
train_loss0.1089
train_loss0.1090
Finished training epoch 17


HBox(children=(FloatProgress(value=0.0, max=490.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=490.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=490.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=490.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=490.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=490.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=490.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=490.0), HTML(value='')))









train_loss0.1088
train_loss0.1082
train_loss0.1087
train_loss0.1085
train_loss0.1090
train_loss0.1089
train_loss0.1080
train_loss0.1087
Finished training epoch 18


HBox(children=(FloatProgress(value=0.0, max=490.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=490.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=490.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=490.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=490.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=490.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=490.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=490.0), HTML(value='')))









train_loss0.1079
train_loss0.1086
train_loss0.1086
train_loss0.1089
train_loss0.1081
train_loss0.1084
train_loss0.1087
train_loss0.1088
Finished training epoch 19


HBox(children=(FloatProgress(value=0.0, max=490.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=490.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=490.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=490.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=490.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=490.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=490.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=490.0), HTML(value='')))









train_loss0.1088
train_loss0.1078
train_loss0.1085
train_loss0.1085
train_loss0.1080
train_loss0.1083
train_loss0.1086
train_loss0.1087
Finished training epoch 20


HBox(children=(FloatProgress(value=0.0, max=490.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=490.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=490.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=490.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=490.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=490.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=490.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=490.0), HTML(value='')))









train_loss0.1084
train_loss0.1087
train_loss0.1077
train_loss0.1079
train_loss0.1086
train_loss0.1085
train_loss0.1083
train_loss0.1085


In [None]:
encoder = Resnet9(1,1);decoder =Resnet9(1,1)

In [None]:
# Load two previously saved sets of weights into the two model instances.
# NOTE(review): the training cell in this notebook writes
# 'encoder_state_dict_resnet.pth'; neither file name below is produced by the
# visible code — confirm these checkpoints exist and are the intended ones.
decoder.load_state_dict(torch.load('encoder_state_dict.pth'))
encoder.load_state_dict(torch.load('encoder_state_dict_1.pth'))

In [None]:
# Check whether the two models hold identical weights. torch.eq only accepts
# tensors, so the state dicts (ordered mappings of name -> tensor) have to be
# compared entry by entry; both models are Resnet9(1, 1), so the keys match.
decoder_state = decoder.state_dict()
encoder_state = encoder.state_dict()
all(torch.equal(tensor, encoder_state[key]) for key, tensor in decoder_state.items())

In [None]:
encoder.encode

In [None]:
import torch
# Load a saved per-epoch loss history for inspection.
# NOTE(review): the training cell in this notebook saves 'history_resnet.pth',
# not 'history_4.pth' — confirm this is the intended file.
history = torch.load('history_4.pth')


In [None]:
history