Encode the scaled data with the trained encoder


In [13]:
# Install the pinned Cloud TPU client and the torch_xla 1.7 wheel (wheel tag: CPython 3.7, linux x86_64).
!pip install cloud-tpu-client==0.10 https://storage.googleapis.com/tpu-pytorch/wheels/torch_xla-1.7-cp37-cp37m-linux_x86_64.whl

Collecting cloud-tpu-client==0.10
  Downloading https://files.pythonhosted.org/packages/56/9f/7b1958c2886db06feb5de5b2c191096f9e619914b6c31fdf93999fdbbd8b/cloud_tpu_client-0.10-py3-none-any.whl
Collecting torch-xla==1.7
  Downloading https://storage.googleapis.com/tpu-pytorch/wheels/torch_xla-1.7-cp37-cp37m-linux_x86_64.whl (133.6MB)
     |████████████████████████████████| 133.6MB 65kB/s
Collecting google-api-python-client==1.8.0
  Downloading https://files.pythonhosted.org/packages/9a/b4/a955f393b838bc47cbb6ae4643b9d0f90333d3b4db4dc1e819f36aad18cc/google_api_python_client-1.8.0-py3-none-any.whl (57kB)
     |████████████████████████████████| 61kB 2.1MB/s
Installing collected packages: google-api-python-client, cloud-tpu-client, torch-xla
  Found existing installation: google-api-python-client 1.7.12
    Uninstalling google-api-python-client-1.7.12:
      Successfully uninstalled google-api-python-client-1.7.12
Successfully installed cloud-tpu-client-0.10 google-api-python-client-1.8.0 torch-xla-1.7

In [14]:
import torch
import numpy as np
import torch.nn as nn
from sklearn.decomposition import PCA
from torch.utils.data import DataLoader,random_split
from torch.utils.data.distributed import DistributedSampler

import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl
import torch_xla.distributed.xla_multiprocessing as xmp



In [15]:
# Load the preprocessed dataset (presumably whitened, per the file name); map_fn slices it by row.
train_dataset = torch.load(f='raw_cv_whiten.pt')

In [16]:
def make_4_dim(data):
    """Insert singleton axes at dims 1 and 3.

    For a 2-D ``(N, D)`` tensor the result is ``(N, 1, D, 1)``: an NCHW batch
    with one channel and width 1, which is what the conv blocks consume.
    """
    with_channel = data.unsqueeze(1)
    return with_channel.unsqueeze(3)

def make_2_dim(data):
    """Remove the singleton axes added by ``make_4_dim``.

    Squeezes dim 1, then dim 2 of that result (original dims 1 and 3), so an
    ``(N, 1, D, 1)`` tensor becomes ``(N, D)``. ``squeeze`` leaves a dim
    untouched when it is not of size 1.
    """
    without_channel = data.squeeze(1)
    return without_channel.squeeze(2)
def conv_block1(in_channels, out_channels, kernel_size, stride, padding):
    """Build a transposed-conv block: ConvTranspose2d -> BatchNorm2d -> LeakyReLU(0.2).

    The LeakyReLU runs in place. Children are indexed 0, 1, 2 inside the
    returned ``nn.Sequential``, so ``state_dict`` keys look like ``'0.weight'``,
    ``'1.running_mean'``.
    """
    return nn.Sequential(
        nn.ConvTranspose2d(in_channels=in_channels, out_channels=out_channels,
                           kernel_size=kernel_size, stride=stride, padding=padding),
        nn.BatchNorm2d(out_channels),
        nn.LeakyReLU(negative_slope=0.2, inplace=True),
    )

def conv_block2(in_channels, out_channels, kernel_size, stride, padding, pool=False):
    """Build a conv block: Conv2d -> BatchNorm2d -> LeakyReLU(0.2), plus MaxPool2d(2) if ``pool``.

    The LeakyReLU runs in place. Children are indexed 0, 1, 2 (and 3 for the
    optional pool) inside the returned ``nn.Sequential``.
    """
    conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                     kernel_size=kernel_size, stride=stride, padding=padding)
    optional_pool = [nn.MaxPool2d(2)] if pool else []
    return nn.Sequential(conv,
                         nn.BatchNorm2d(out_channels),
                         nn.LeakyReLU(negative_slope=0.2, inplace=True),
                         *optional_pool)
    

class Resnet9(nn.Module):
      """Convolutional autoencoder built from conv_block1/conv_block2 with residual pairs.

      encode(): conv1-conv3 (ConvTranspose2d blocks) enlarge the input, res1 adds
      a residual pair, conv4-conv5 (Conv2d blocks) shrink it, res2 adds a second
      residual pair and conv6 reduces it to a single channel.
      decode(): the same layout with conv7-conv9, res3, conv10-conv11, res4 and a
      bare Conv2d (conv12, no BatchNorm/activation) projecting to out_channel.

      NOTE(review): attribute names (conv1 ... conv12, res1 ... res4) are the
      state_dict keys restored from 'encoder_state_dict_resnet.pth'; renaming
      any of them breaks loading.
      """
      def __init__(self, in_channel, out_channel):
        super().__init__()

        # Shape comments are per-sample (C, H, W) for an input of shape
        # (N, in_channel, H, 1) such as make_4_dim produces. They are derived
        # from the Conv2d / ConvTranspose2d output-size formulas -- verify.
        # Example: H=2 gives 6x4 -> 14x10 -> 30x22 -> 14x10 -> 6x4 -> 2x1.

        # Encode 1: transposed convs (k=4, s=2, p=0) grow each spatial dim.
        self.conv1 = conv_block1(in_channel, 64,kernel_size=4,stride=2,padding=0)         ## (64, 2H+2, 4)
        self.conv2 = conv_block1(64, 128,kernel_size=4,stride=2,padding=0)                ## (128, 4H+6, 10)
        self.conv3 = conv_block1(128, 256,kernel_size=4,stride=2,padding=0)               ## (256, 8H+14, 22)
        self.res1 = nn.Sequential(conv_block1(256, 256,kernel_size=3,stride=1,padding=1), ## (256, 8H+14, 22)
                                  conv_block1(256, 256,kernel_size=3,stride=1,padding=1))  ## (256, 8H+14, 22)

        # Encode 2: strided convs (k=4, s=2, p=0) shrink back down.
        self.conv4 = conv_block2(256, 128,kernel_size=4,stride=2,padding=0)         ## (128, 4H+6, 10)
        self.conv5 = conv_block2(128, 64,kernel_size=4,stride=2,padding=0)                ## (64, 2H+2, 4)
        self.res2 = nn.Sequential(conv_block2(64,64,kernel_size=3,stride=1,padding=1), ## (64, 2H+2, 4)
                                  conv_block2(64,64,kernel_size=3,stride=1,padding=1))  ## (64, 2H+2, 4)
        self.conv6 = conv_block2(64,1,kernel_size=4,stride=2,padding=0)                 ## latent: (1, H, 1)

        # Decode 1: conv7 takes in_channel channels, but the latent from conv6
        # has 1 channel, so decode(encode(x)) only lines up when in_channel == 1.
        self.conv7 = conv_block1(in_channel, 64,kernel_size=4,stride=2,padding=0)         ## (64, 2H+2, 4)
        self.conv8 = conv_block1(64, 128,kernel_size=4,stride=2,padding=0)                ## (128, 4H+6, 10)
        self.conv9 = conv_block1(128, 256,kernel_size=4,stride=2,padding=0)               ## (256, 8H+14, 22)
        self.res3 = nn.Sequential(conv_block1(256, 256,kernel_size=3,stride=1,padding=1), ## (256, 8H+14, 22)
                                  conv_block1(256, 256,kernel_size=3,stride=1,padding=1))  ## (256, 8H+14, 22)

        # Decode 2
        self.conv10 = conv_block2(256, 128,kernel_size=4,stride=2,padding=0)         ## (128, 4H+6, 10)
        self.conv11 = conv_block2(128, 64,kernel_size=4,stride=2,padding=0)                ## (64, 2H+2, 4)
        self.res4 = nn.Sequential(conv_block2(64,64,kernel_size=3,stride=1,padding=1), ## (64, 2H+2, 4)
                                  conv_block2(64,64,kernel_size=3,stride=1,padding=1))  ## (64, 2H+2, 4)
        # Output layer: plain Conv2d, no BatchNorm / LeakyReLU.
        self.conv12 = nn.Conv2d(64,out_channel,kernel_size=4,stride=2,padding=0)                 ## (out_channel, H, 1)

      def encode(self,in_data):
          """Map in_data to the latent tensor produced by conv6 (1 channel).

          in_data is cast with .float() (float32) first. For an
          (N, in_channel, H, 1) input the latent is expected to be (N, 1, H, 1)
          per the shape notes in __init__ -- verify.
          """
          out = self.conv1(in_data.float())
          out = self.conv2(out)
          out = self.conv3(out)
          # Residual add. res1 starts with a (transposed) conv, so its in-place
          # LeakyReLU never overwrites `out` before the addition.
          out = self.res1(out)+out
          out = self.conv4(out)
          out = self.conv5(out)
          out = self.res2(out)+out  # residual add, same reasoning as res1
          out = self.conv6(out)
          return out

      def decode(self,lat_data):
          """Reconstruct from a latent tensor. conv7 expects in_channel channels.

          lat_data is cast with .float() first. The result comes from conv12,
          a bare Conv2d with out_channel channels and no final activation.
          """
          out = self.conv7(lat_data.float())
          out = self.conv8(out)
          out = self.conv9(out)
          out = self.res3(out)+out  # residual add around res3
          out = self.conv10(out)
          out = self.conv11(out)
          out = self.res4(out)+out  # residual add around res4
          out = self.conv12(out)
          return out

In [18]:
# One input channel and one output channel; in_channel must be 1 for
# decode() to accept the 1-channel latent that encode() produces.
encoder = Resnet9(in_channel=1, out_channel=1)

In [19]:
# Restore trained weights; strict key matching requires Resnet9's attribute names to be unchanged.
encoder.load_state_dict(torch.load('encoder_state_dict_resnet.pth'), strict=True)

<All keys matched successfully>

In [20]:
# Wrap the model for use by the processes started with xmp.spawn; map_fn calls
# WRAPPED_MODEL.to(device) in each one. NOTE(review): per the torch_xla docs the
# wrapper keeps a single shared-memory copy of the weights -- confirm for 1.7.
WRAPPED_MODEL = xmp.MpModelWrapper(encoder)

In [21]:
from tqdm.notebook import tqdm

In [22]:
# Run configuration shared with map_fn and the xmp.spawn launcher.
FLAGS = dict(
    batch_size=1028,
    num_workers=4,
    num_cores=8,  # passed to xmp.spawn as nprocs
)


In [23]:
import gc

# Full collection (generation 2 is gc.collect's default); the cell shows the count of unreachable objects found.
gc.collect(generation=2)

525

In [None]:
def map_fn(rank, FLAGS):
    """Encode every row of ``train_dataset`` across all TPU cores and save to 'latent.pt'.

    Run once per process by ``xmp.spawn``. Each process takes one contiguous
    shard of rows (its XLA ordinal picks the shard) and encodes it in chunks of
    ``FLAGS['batch_size']`` rows with the model in eval mode and autograd off.
    Each process writes its chunks to a temporary ``latent_shard_<ordinal>.pt``.
    After a rendezvous, the master process concatenates the shards in ordinal
    order, which is the original row order. It saves the result as 'latent.pt'
    (a CPU tensor) and deletes the temporary files.

    Args:
        rank: Process index passed by ``xmp.spawn``; not used directly
            (sharding uses ``xm.get_ordinal()``).
        FLAGS: Run configuration; only ``FLAGS['batch_size']`` is read here.

    NOTE(review): the shards are exchanged through files in the working
    directory, so all processes must share one filesystem (single-host TPU,
    as with ``start_method='fork'``).
    """
    import os

    torch.manual_seed(1234)
    device = xm.xla_device()
    model = WRAPPED_MODEL.to(device)
    # Inference: BatchNorm must use (and not update) its stored running
    # statistics instead of per-chunk batch statistics.
    model.eval()

    world_size = xm.xrt_world_size()
    ordinal = xm.get_ordinal()
    n_rows = len(train_dataset)
    shard_size = -(-n_rows // world_size)  # ceiling division, so no row is dropped
    start = min(ordinal * shard_size, n_rows)
    stop = min(start + shard_size, n_rows)
    chunk_size = max(1, int(FLAGS['batch_size']))

    chunks = []
    with torch.no_grad():  # don't keep an autograd graph alive for every chunk
        for lo in tqdm(range(start, stop, chunk_size),
                       disable=not xm.is_master_ordinal()):
            hi = min(lo + chunk_size, stop)  # never read into the next shard
            batch = make_4_dim(train_dataset[lo:hi].to(device))
            # .cpu() forces the lazy XLA computation for this chunk to run now,
            # instead of accumulating one ever-growing pending graph on device.
            chunks.append(make_2_dim(model.encode(batch)).cpu())

    # A list of chunks (possibly empty) is saved, so an empty shard is fine.
    torch.save(chunks, 'latent_shard_{}.pt'.format(ordinal))
    xm.rendezvous('latent_shards_written')

    if xm.is_master_ordinal():
        parts = []
        for shard in range(world_size):
            shard_path = 'latent_shard_{}.pt'.format(shard)
            parts.extend(torch.load(shard_path))
            os.remove(shard_path)
        torch.save(torch.cat(parts), 'latent.pt')
    # Hold every process until the master has finished writing 'latent.pt'.
    xm.rendezvous('latent_saved')

if __name__ == '__main__':
    # One map_fn process per TPU core. 'fork' lets the children inherit the
    # notebook globals map_fn reads (train_dataset, WRAPPED_MODEL, helpers).
    spawn_kwargs = dict(args=(FLAGS,),
                        nprocs=FLAGS['num_cores'],
                        start_method='fork')
    xmp.spawn(map_fn, **spawn_kwargs)