In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import transforms
import matplotlib.pyplot as plt
from torch.optim import Adam
from tqdm import tqdm
from torch.utils.data import DataLoader
import random
from collections import Counter
import math
from PIL import Image
from pathlib import Path


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
batch_size = 32
max_epoch = 10
device = torch.device(
    'cuda') if torch.cuda.is_available() else torch.device('cpu')
train_data_path = './numpy_datasets/train'
test_data_path = './numpy_datasets/test'
image_path = '../Images'
data_save_path = './TransNets'
lr = 0.000

In [None]:
class MyDataset(torch.utils.data.Dataset):
  def __init__(self, root):
    self.transform = transforms.Compose([
      transforms.Resize((224,224)),
      transforms.CenterCrop(224),
      transforms.ToTensor(),
      transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    images_path = Path(root)

    images_list = list(image_path.glob('*.jpg'))
    images_list_str = [str(x) for x in images_list]
    self.images = images_list_str


  def __getitem__(self, item):
    pass
    

In [5]:
import os
os.path

<module 'ntpath' from 'c:\\Users\\zhouzihan\\.conda\\envs\\project\\lib\\ntpath.py'>

In [3]:
train_users = np.load(f'{train_data_path}/train_users.npy')
train_items = np.load(f'{train_data_path}/train_items.npy')
train_ratings = np.load(f'{train_data_path}/train_ratings.npy')
train_reviews = np.load(
    f'{train_data_path}/train_reviews.npy', allow_pickle=True)
train_descriptions = np.load(
    f'{train_data_path}/train_descriptions.npy', allow_pickle=True)
train_prices = np.load(f'{train_data_path}/train_prices.npy')
train_categories = np.load(f'{train_data_path}/train_categories.npy')

test_users = np.load(f'{test_data_path}/test_users.npy')
test_items = np.load(f'{test_data_path}/test_items.npy')
test_ratings = np.load(f'{test_data_path}/test_ratings.npy')
test_reviews = np.load(f'{test_data_path}/test_reviews.npy', allow_pickle=True)
test_descriptions = np.load(
    f'{test_data_path}/test_descriptions.npy', allow_pickle=True)
test_prices = np.load(f'{test_data_path}/test_prices.npy')
test_categories = np.load(f'{test_data_path}/test_categories.npy')

FileNotFoundError: [Errno 2] No such file or directory: './numpy_datasets/train/train_users.npy'

In [None]:
def cycle(iterable):
  while True:
    for x in iterable:
      yield x


train_iterator = cycle(DataLoader(
    np.unique(train_users), batch_size=batch_size, shuffle=True))

test_iterator = cycle(DataLoader(
    np.unique(test_users), batch_size=batch_size, shuffle=True
))

In [None]:
image_loader = cycle(DataLoader(torchvision.datasets.ImageFolder(
  '../Images',
  transform=torchvision.transforms.Compose([
    torchvision.transforms.Resize((224,224)),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(mean=[.5, .5, .5], std=[.5, .5, .5])
  ])
))
)

In [6]:
'''
  input: user u, sets
  item: i
  output: user_reviews <- [reviews of u],  
          user[descriptions that u has seen], 
          most liked category of u, 
          average price of u purchasing items, 
          item_reviews <- [reviews of i],
          item_description <- description of i,
          item_price <- price of i,
          item_image <- image of i,
          unseen_image <- the image of an unseen item to u,
          rev_ui <- rating of u on i

  status: this function works fine
'''


def sample_user(user, user_set, item_set, rev_set, desc_set, cat_set, price_set, img_set, rating_set):
  # -> u
  user = int(user)
  user_meta = []
  item_meta = []
  user_reviews = []
  user_descriptions =[]
  user_category = []
  user_price = []
  item_reviews = []
  item_description = None
  item_category = None
  item_price = None
  item_image = None
  unseen_image = None
  rev_ui = None

  user_indicies = np.where(user_set == user)[0]

  # generates the index of a rated item for the user
  chosen_user_item_idx = user_indicies[random.randint(0, len(user_indicies)-1)]

  # randomly select a rated item -> i
  item = item_set[chosen_user_item_idx]
  item_description = desc_set[chosen_user_item_idx]
  item_category = cat_set[chosen_user_item_idx]
  item_price = price_set[chosen_user_item_idx]

  # gets the rating of u, i
  rating_ui = rating_set[chosen_user_item_idx]

  # gets item indices to get all reviews of the item -> i
  item_indicies = np.where(item_set == item)[0]

  # all indicies in arrays are same (identical to user review number)
  for u_idx in user_indicies:
    # out of reviews, other metadata is always accessible either or not in test mode
    user_descriptions.append(desc_set[u_idx])
    user_category.append(cat_set[u_idx])
    user_price.append(price_set[u_idx])
    if u_idx != chosen_user_item_idx:
      user_reviews.append(rev_set[u_idx])
    else:
      rev_ui = rev_set[u_idx]

  for i_idx in item_indicies:
    if i_idx != chosen_user_item_idx:
      item_reviews.append(rev_set[i_idx])


  user_reviews, user_descriptions, item_reviews = \
    np.array(
      user_reviews).flatten(), np.array(user_descriptions).flatten(), np.array(item_reviews).flatten()

  # according to the transnets paper - regularise the combined strings
  if len(user_reviews) > 1000:
    user_reviews = user_reviews[:1000]
  if len(user_descriptions) > 1000:
    user_descriptions = user_descriptions[:1000]
  if len(item_reviews) > 1000:
    item_reviews = item_reviews[:1000]

  if len(user_reviews) == 0:
    user_reviews = np.zeros((64,))
  if len(item_reviews) == 0:
    item_reviews = np.zeros((64,))
  if len(user_descriptions) == 0:
    user_descriptions = np.zeros((64,))

  # for computational correctness
  user_reviews, user_descriptions, item_reviews, item_description = torch.from_numpy(np.array(user_reviews)), torch.from_numpy(np.array(user_descriptions)),\
      torch.from_numpy(np.array(item_reviews)), torch.from_numpy(np.array(item_description))
  
  user_category = Counter(user_category).most_common(1)[0][0]
  user_price = sum(user_price)/len(user_price)
  
  rev_ui = torch.from_numpy(np.array(rev_ui))

  user_meta = [torch.tensor(user_category), torch.tensor([user_price],dtype=torch.float32), user_descriptions, user_reviews]
  item_meta = [torch.tensor(item), torch.tensor(item_category), torch.tensor([item_price], dtype=torch.float32), item_description, item_reviews]

  return user_meta, item_meta, item_image, unseen_image, rev_ui, rating_ui

def l1_loss(pred, y):
  return torch.mean(y - pred)


def l2_loss(pred, y):
  return torch.mean(torch.pow(pred,2) - torch.pow(y, 2))


def weights_init(m):
  if type(m) in (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Embedding):
    nn.init.xavier_normal_(m.weight)


In [7]:
class FM(nn.Module):
    def __init__(self, latent_dim, fea_num):
        super().__init__()

        self.latent_dim = latent_dim
        self.w0 = nn.Parameter(torch.zeros([1, ]))
        self.w1 = nn.Parameter(torch.rand([fea_num, 1]))
        self.w2 = nn.Parameter(torch.rand([fea_num, latent_dim]))

    def forward(self, inputs):
        # inputs = inputs.long()
        first_order = self.w0 + torch.mm(inputs, self.w1)
        second_order = 1/2 * torch.sum(
            torch.pow(torch.mm(inputs, self.w2), 2) -
            torch.mm(torch.pow(inputs, 2), torch.pow(self.w2, 2)),

            dim=1,
            keepdim=True
        )

        return first_order + second_order


class GMF(nn.Module):
    def __init__(self, inp_range, latent_dim=20, dropout = True):
        super().__init__()
        self.emb = nn.Embedding(
            num_embeddings=inp_range, embedding_dim=latent_dim)
        self.dropout = nn.Dropout(0.5)
        self.use_dropout = dropout

    def forward(self, inputs):
        embedding = self.emb(inputs)
        if self.use_dropout:
          embedding = self.dropout(embedding)

        return embedding

class TextCNN(nn.Module):
  def __init__(self, kernel_size, neurons, latent_vector_size, vocab_size, embed_size):
      super().__init__()
      self.embedding = nn.Embedding(vocab_size, embed_size)
      self.dropout = nn.Dropout(0.5)
      self.decoder = nn.Linear(300, latent_vector_size)
      self.pooling = nn.AdaptiveMaxPool1d(1)
      self.convs = nn.ModuleList()
      self.tanh = nn.Tanh()
      self.embed_size = embed_size
      for _ in range(3):
        self.convs.append(nn.Conv1d(embed_size, neurons, kernel_size))

  def forward(self, inputs):
    embeddings = self.embedding(inputs.long())

    print(embeddings.shape)
    embeddings = embeddings.view(1, len(embeddings), self.embed_size).permute(0, 2, 1)

    encoding = torch.cat([
        torch.squeeze(self.pooling(conv(embeddings)), dim=-1)
        for conv in self.convs
    ], dim=1)

    out = self.decoder(self.dropout(encoding))
    return out


class TransformMLP(nn.Module):
  def __init__(self, concated_size, latent_vector_size):
      super().__init__()
      self.net = nn.Sequential(
          nn.Linear(concated_size, 2*concated_size),
          nn.ReLU(inplace=True),
          nn.Linear(2*concated_size, latent_vector_size),
          nn.Dropout(0.5, inplace=True)
      )

  def forward(self, x):
    out = self.net(x)
    out = out.view(1, out.shape[0])
    return out


# https://d2l.ai/chapter_convolutional-modern/resnet.html
class Residual(nn.Module):
  def __init__(self, input_channels, num_channels, use_1x1_conv=False, downsample=1):
    super().__init__()
    self.net = nn.Sequential(
        nn.Conv2d(input_channels, num_channels, kernel_size=3, padding=1,stride=downsample),
        nn.BatchNorm2d(num_channels),
        nn.ReLU(inplace=True),
        nn.Conv2d(num_channels, num_channels, kernel_size=3, padding=1),
        nn.BatchNorm2d(num_channels)
    )
    # reduce image size by 2
    if use_1x1_conv:
      self.net2 = nn.Conv2d(input_channels, num_channels,
                            kernel_size=1, stride=downsample)
    else:
      self.net2 = None

  def forward(self, x):
    res = self.net(x)
    if self.net2 != None:
      x = self.net2(x)
    res += x
    out = F.relu(res)
    return out

def resnet_block(input_channels, num_channels, num_residuals, first_block=False):
  blk = []
  for i in range(num_residuals):
    if i == 0 and not first_block:
      blk.append(Residual(input_channels, num_channels,
                 use_1x1_conv=True, downsample=2))
    else:
      blk.append(Residual(num_channels, num_channels))
  return blk


# resnet siamesecnn

'''
Here, instead of seeking a final layer that can be adapted to general-purpose prediction tasks, 
we hope to learn a representation whose dimensions explain the variance in users\' fashion preferences
'''
class SiameseCNN(nn.Module):
  def __init__(self):
    super().__init__()
    self.net = nn.Sequential(
        # b1
        nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3),  # 112
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),  # 56
        ),

        # b2 - first block
        nn.Sequential(*resnet_block(64, 64, 2, first_block=True)), # x. + y.  56 / 4 + 56 / 4  = 28

        # b3
        nn.Sequential(*resnet_block(64, 128, 2)), # out: 7 + 7 = 14

        # b4
        nn.Sequential(*resnet_block(128, 256, 2)), # out: 3.5 + 3.5 = 7

        # b5
        nn.Sequential(*resnet_block(256, 512, 2)), # out: 3.5 ?

        nn.AdaptiveAvgPool2d((1, 1)),
        nn.Flatten(),
        nn.Linear(512, 512),
        nn.Dropout(inplace=True),
        nn.Linear(512, 100)
    )

  def forward(self, x1, x2):
    return self.net(x1), self.net(x2)

textCNN_I = TextCNN(3, 100, 50, 50000, 64).apply(weights_init).to(device)
textCNN_U = TextCNN(3, 100, 50, 50000, 64).apply(weights_init).to(device)
textCNN_T = TextCNN(3, 100, 50, 50000, 64).apply(weights_init).to(device)


transform = TransformMLP(100, 50).apply(weights_init).to(device)
fm_T = FM(8, 50).apply(weights_init).to(device)
fm_S = FM(8, 220).apply(weights_init).to(device)

'''
Meta data
'''
# user latent vectors -> this embedding should be updated
mf_u = GMF(max(train_users)+1, 10, True).apply(weights_init).to(device)
# item latent vectors -> this embedding should be updated
mf_i = GMF(max(max(train_items), max(test_items))+1, 10, True).apply(weights_init).to(device)

# user description latent vectors -> this nn should be updated
textCNN_DU = TextCNN(3, 100, 64, 50000, 36).apply(weights_init).to(device)
# item description latent vectors -> this nn should be updated
textCNN_DI = TextCNN(3, 100, 36, 50000, 36).apply(weights_init).to(device)
# image latent vectors -> this nn should be updated
imageCNN = SiameseCNN().apply(weights_init).to(device)

optimiser_I = torch.optim.Adam(params=textCNN_I.parameters(), lr=lr)
optimiser_U = torch.optim.Adam(params=textCNN_U.parameters(), lr=lr)
optimiser_T = torch.optim.Adam(params=textCNN_T.parameters(), lr=lr)
optimiser_DU = torch.optim.Adam(params=textCNN_DU.parameters(), lr=lr)
optimiser_DI = torch.optim.Adam(params=textCNN_DI.parameters(), lr=lr)
optimiser_trans = torch.optim.Adam(params=transform.parameters(), lr=lr)
optimiser_FMT = torch.optim.Adam(params=fm_T.parameters(), lr=lr)
optimiser_FMS = torch.optim.Adam(params=fm_S.parameters(), lr=lr)
optimiser_image = torch.optim.Adam(params=imageCNN.parameters(), lr = lr)
optimiser_MFU = torch.optim.Adam(params=mf_u.parameters(), lr = lr)
optimiser_MFI = torch.optim.Adam(params=mf_i.parameters(), lr = lr)


def save_training():
  torch.save(
      {'textCNNI': textCNN_I.state_dict(),
       'textCNNU': textCNN_U.state_dict(),
       'textCNNT': textCNN_T.state_dict(),
       'transform': transform.state_dict(),
       'fmT': fm_T.state_dict(),
       'fmS': fm_S.state_dict(),
       'mf_u': mf_u.state_dict(),
       'mf_i':mf_i.state_dict(),
       'textCNN_DU':textCNN_DU.state_dict(),
       'textCNN_DI':textCNN_DI.state_dict(),
       'imageCNN':imageCNN.state_dict(),
       'epoch': epoch},
      f'{data_save_path}/save.chkpt')


def load_training():
  global textCNN_I
  global textCNN_U
  global textCNN_T
  global transform
  global fm_T
  global fm_S
  global mf_u
  global mf_i
  global textCNN_DU
  global textCNN_DI
  global imageCNN

  params = torch.load(f'{data_save_path}/save.chkpt')

  textCNN_I.load_state_dict(params['textCNNI'])
  textCNN_U.load_state_dict(params['textCNNU'])
  textCNN_T.load_state_dict(params['textCNNT'])
  transform.load_state_dict(params['transform'])
  fm_T.load_state_dict(params['fmT'])
  fm_S.load_state_dict(params['fmS'])
  mf_u.load_state_dict(params['fm_u'])
  mf_i.load_state_dict(params['fm_i'])
  textCNN_DU.load_state_dict(params['textCNN_DU'])
  textCNN_DI.load_state_dict(params['textCNN_DI'])
  imageCNN.load_state_dict(params['imageCNN'])


NameError: name 'train_users' is not defined

In [None]:
def eval_rmse(preds, real):
    rmse = 0
    # produce known rui set
    for i in range(len(preds)):
        rmse += (preds[i] - real[i]) ** 2

    rmse = math.sqrt(rmse/len(preds))

    return rmse


def eval_model():
  preds = []
  real = []
  for u in tqdm(np.unique(test_users)):
    user_meta, item_meta, item_image, unseen_image, _, rating_ui = sample_user(
        u, test_users, test_items, test_reviews, test_descriptions, test_categories, test_prices, None, test_ratings)
    user_category, user_price, user_descriptions, user_reviews = user_meta
    i, item_category, item_price, item_description, item_reviews = item_meta

    # transform the input
    latent_rev_user = textCNN_U(user_reviews)
    latent_rev_item = textCNN_I(item_reviews)
    z0 = torch.flatten(torch.cat((latent_rev_user, latent_rev_item), dim=0))
    z_L = transform(z0)
    
    # predict using the transformed input
    latent_desc_i = textCNN_DI(item_description)  # -> 36
    latent_desc_u = textCNN_DU(user_descriptions)  # -> 64

    '''User and Item'''
    latent_uid = mf_u(torch.tensor(u))
    latent_iid = mf_i(i)  # -> 20

    '''Meta data'''
    cat_embed_u, cat_embed_i = F.one_hot(
        user_category, 24), F.one_hot(item_category, 24)  # -> 24

    latent_final = torch.cat(
        (z_L.flatten(),  # 50
         latent_uid,  # 10
         latent_iid,  # 10
         latent_desc_i.flatten(),  # 36
         latent_desc_u.flatten(),  # 64
         cat_embed_u,  # 24
         cat_embed_i,  # 24
         user_price,  # 1
         item_price),  # 1
        dim=0).view(1, 220)

    '''Image - trained seperatedly'''
    # img_result_1, img_result_2 = imageCNN(item_image, unseen_image)
    # s = imageCNN.state_dict()
    # # regularisation term
    # nn_regularisers = 0
    # for k in s:
    #   params = s[k].flatten()
    #   nn_regularisers += torch.sum(params ** 2)/2
    # cost_train = torch.sum(torch.log(F.sigmoid(torch.sum(torch.dot(torch.subtract(img_result_1, img_result_2),latent_final)).flatten())))
    # cost_train -= 0.001 * nn_regularisers
    # loss_image += cost_train

    pred = fm_S(latent_final)
    preds.append(pred)
    real.append(rating_ui)

  return eval_rmse(preds, real)


In [132]:
# needs to be changed

best_trans_params = None
best_source_params = None
best_rmse = np.inf

for epoch in range(0, max_epoch):
  for i in tqdm(range(1000)):
    # batch_size = 50
    users = next(train_iterator)
    users = users.to(device)
    loss_T = 0
    loss_trans = 0
    loss_S = 0
    loss_image = 0

    optimiser_T.zero_grad()
    optimiser_FMT.zero_grad()
    optimiser_I.zero_grad()
    optimiser_U.zero_grad()
    optimiser_trans.zero_grad()
    optimiser_FMS.zero_grad()
    optimiser_DI.zero_grad()
    optimiser_DU.zero_grad()
    optimiser_image.zero_grad()
    optimiser_MFI.zero_grad()
    optimiser_MFU.zero_grad()

    for u in users:
      '''Get data needed'''
      user_meta, item_meta, item_image, unseen_image, rev_ui, rating_ui = sample_user(u, train_users, train_items, train_reviews, train_descriptions, train_categories, train_prices, None, train_ratings)
      user_category, user_price, user_descriptions, user_reviews = user_meta
      i, item_category, item_price, item_description, item_reviews = item_meta

      '''
      Training is as follows:
      - Components to train:
      1. TransNets
      2. SiameseNetwork
      3. Description Extractor
      4. User / Item embedding MFs

      - How they are trained:
      1. TransNets: (Completed)
        0. Principle: User review must be isolated from prediction networks in training time
        1. Train target network predictor fm_T with real rev_ui, backpropagate loss_T
        2. Train transform network (textCNN_U, textCNN_I, transform) by constuction z_L and compare with latent_rev_ui from target network
        3. Train source network predictor fm_S with z_L
      2. SiameseNetwork
        1. Trained Individually by minimising the distance of user's preferred and not preferred images
        2. The result is used to weight the final prediction
      3. Description Extractors
        1. Train with source network by generating latent_desc_u, latent_desc i
        2. Feed latent_desc_u, latent_desc_i into fm_S and backpropagate errors
      4. User / Item meta MFs
        1. Train with source network by generating latent_uid, latent_iid
        2. Feed latent_uid, latent_iid into fm_S and backpropagate errors
      '''


      '''TARGET AND TRANSFORM NETWORKS ARE NOT TO BE AFFECTED BY METADATA'''


      '''Train target network on the actual review'''
      latent_rev_ui = textCNN_T(rev_ui)
      pred_T = fm_T(latent_rev_ui)
      loss_T += l1_loss(rating_ui, pred_T)


      '''THE SECTIONS BELOW ARE USED FOR PREDICTION'''

      '''Learn to transform'''
      latent_rev_user = textCNN_U(user_reviews)
      latent_rev_item = textCNN_I(item_reviews)
      z0 = torch.flatten(torch.cat((latent_rev_user, latent_rev_item), dim=0))
      z_L = transform(z0) # 50
      loss_trans += l2_loss(latent_rev_ui, z_L)

      '''Train a predictor on the transformed input
      This is where the altering begins'''

      '''Description Extractor'''
      latent_desc_i = textCNN_DI(item_description) # -> 36
      latent_desc_u = textCNN_DU(user_descriptions) # -> 64

      '''User and Item'''
      latent_uid = mf_u(u)
      latent_iid = mf_i(i) # -> 20

      '''Meta data'''
      cat_embed_u, cat_embed_i = F.one_hot(user_category, 24), F.one_hot(item_category, 24)  # -> 24

      latent_final = torch.cat(
        (z_L.flatten(), # 50
        latent_uid, # 10
        latent_iid, # 10
        latent_desc_i.flatten(), # 36
        latent_desc_u.flatten(), # 64
        cat_embed_u, # 24
        cat_embed_i, # 24
        user_price, # 1
        item_price), # 1
        dim=0).view(1,220)

      '''Image - trained seperatedly'''
      # img_result_1, img_result_2 = imageCNN(item_image, unseen_image)
      # s = imageCNN.state_dict()
      # # regularisation term
      # nn_regularisers = 0
      # for k in s:
      #   params = s[k].flatten()
      #   nn_regularisers += torch.sum(params ** 2)/2
      # cost_train = torch.sum(torch.log(F.sigmoid(torch.sum(torch.dot(torch.subtract(img_result_1, img_result_2),latent_final)).flatten())))
      # cost_train -= 0.001 * nn_regularisers
      # loss_image += cost_train
      
      pred_S = fm_S(latent_final)
      loss_S += l1_loss(rating_ui, pred_S)

    loss_S /= batch_size
    loss_S.backward(retain_graph=True)
    optimiser_FMS.step()
    optimiser_MFU.step()
    optimiser_MFI.step()
    optimiser_DI.step()
    optimiser_DU.step()
    

    loss_trans /= batch_size
    loss_trans.backward(retain_graph=True)
    optimiser_U.step()
    optimiser_I.step()
    optimiser_trans.step()

    loss_T /= batch_size
    loss_T.backward()
    optimiser_T.step()
    optimiser_FMT.step()

    # independent backward
    # loss_image /= batch_size
    # loss_image.backward()
    # optimiser_image.step()

  with torch.no_grad():
    save_training()
    rmse = eval_model()
    print(f"epoch: [{epoch}/{max_epoch}]: rmse - {rmse}")
    if rmse < best_rmse:
      best_trans_params = transform.parameters()
      best_source_params = fm_S.parameters()


  0%|          | 0/1000 [00:00<?, ?it/s]

torch.Size([64, 64])
torch.Size([320, 64])
torch.Size([384, 64])
torch.Size([36, 36])
torch.Size([216, 36])
torch.Size([64, 64])
torch.Size([128, 64])
torch.Size([64, 64])
torch.Size([36, 36])
torch.Size([108, 36])
torch.Size([64, 64])
torch.Size([128, 64])
torch.Size([128, 64])
torch.Size([36, 36])
torch.Size([108, 36])
torch.Size([64, 64])
torch.Size([128, 64])
torch.Size([832, 64])
torch.Size([36, 36])
torch.Size([108, 36])
torch.Size([64, 64])
torch.Size([192, 64])
torch.Size([64, 64])
torch.Size([36, 36])
torch.Size([144, 36])
torch.Size([64, 64])
torch.Size([128, 64])
torch.Size([192, 64])
torch.Size([36, 36])
torch.Size([108, 36])
torch.Size([64, 64])
torch.Size([128, 64])
torch.Size([64, 64])
torch.Size([36, 36])
torch.Size([108, 36])
torch.Size([64, 64])
torch.Size([384, 64])
torch.Size([832, 64])
torch.Size([36, 36])
torch.Size([252, 36])
torch.Size([64, 64])
torch.Size([192, 64])
torch.Size([320, 64])
torch.Size([36, 36])
torch.Size([144, 36])
torch.Size([64, 64])
torch.Size

  0%|          | 1/1000 [00:00<14:52,  1.12it/s]

torch.Size([64, 64])
torch.Size([256, 64])





KeyboardInterrupt: 