# StaTexNet - Network Encoding Statistics for Textures
Locally Enforced Connections

## Dependencies & Hyperparameters

In [1]:
import torch
import torchvision
import torchvision.transforms as transforms
import torch.utils.data as data
from torch.utils.data import DataLoader
import torch.nn as nn
import sys
import matplotlib.pyplot as plt
import numpy as np
from pytorch_metric_learning import losses

import utils.statnetencoder as sne
import imp

#!wget https://raw.githubusercontent.com/pytorch/pytorch/master/torch/utils/collect_env.py

#sys.path.append('../')
import steerable
import steerable.utils as utils
from steerable.SCFpyr_PyTorch import SCFpyr_PyTorch

torch.manual_seed(17)

#use GPU 2
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')

#hyperparams
num_epochs = 10
batch_size = 10
crop_size = 128
num_stats = 200
#optimizer_type='adam'
optimizer_type='sgd'
learning_rate = 0.01
num_crops = 5

#dataset location
dtd_folder = '~/data/dtd/'

In [2]:
A = torch.randn(2, 10, 10)
upsample = nn.Upsample(size=24, mode='bilinear')
print(A.shape)
A = torch.unsqueeze(A, 0)
print(A.shape)
A = upsample(A)
print(A.shape)

torch.Size([2, 10, 10])
torch.Size([1, 2, 10, 10])
torch.Size([1, 2, 24, 24])


## Define DataLoader

In [3]:
loading_transforms = torchvision.transforms.Compose([#transforms.CenterCrop(size=300),
                                                    #transforms.RandomRotation(degrees=180),
                                                    transforms.Grayscale(),
                                                    #transforms.TenCrop(size=crop_size),
                                                    #transforms.RandomRotation(degrees=[0,90,180,270]),
                                                    transforms.RandomVerticalFlip(p=0.5),
                                                    transforms.RandomHorizontalFlip(p=0.5),
                                                    transforms.FiveCrop(size=crop_size),
                                                    #transforms.functional.vflip(),
                                                    #transforms.functional.hflip(),
                                                    transforms.Lambda(lambda crops: torch.stack([transforms.PILToTensor()(crop) for crop in crops])),
                                                    transforms.ConvertImageDtype(torch.float32)])
                                                    #transforms.PILToTensor()])

#use training set for now
dtd_dataset = torchvision.datasets.DTD(root='~/data/dtd_torch', split='train', partition=1, 
                                       transform=loading_transforms, target_transform=None,
                                       download=True)

sampler = data.RandomSampler(dtd_dataset)

dtd_dataloader = DataLoader(dtd_dataset, 
                            sampler=sampler,
                            batch_size=batch_size, 
                            shuffle=False)
#dtd_labels = tf

tensor2pil_transform = transforms.ToPILImage()

## Test Dataloader

In [4]:
# for n, texture_batch in enumerate(dtd_dataloader):
#     #grab texture batch and generate matching labels
#     output = texture_batch[0].to(device)
#     output = torch.flatten(output, start_dim=0, end_dim=1)
#     texture_labels = torch.repeat_interleave(torch.arange(batch_size),num_crops)
#     #apply random permutation
#     perm = torch.randperm(batch_size * num_crops)
#     output = output[perm]
#     texture_labels = texture_labels[perm]
#     print(texture_labels)
#     #loop through batch and plot images
#     for j in range(batch_size):
#         plt.figure(figsize=(8,4))
#         for i in range(num_crops):
#             plt.subplot(2,5,i+1)
#             plt.imshow(tensor2pil_transform(output[i+j*num_crops,:,:,:]))
#             plt.axis('off')
#         plt.show()
#     if(n==1):
#         break;
    
#tensor2pil_transform(output[4,0,:,:,:])

## Define Model & Optimizer

In [5]:
imp.reload(sne)
statnet_model = sne.StatNetEncoder(img_size=(crop_size,crop_size),
                                   batch_size=batch_size,
                                   num_stats=num_stats,
                                   vectorized=False,
                                   device=device)
statnet_model.to(device)

#optimizer
if(optimizer_type=='sgd'):
    optimizer = torch.optim.SGD(statnet_model.parameters(), lr=learning_rate)#, momentum=learning_momentum)
elif(optimizer_type=='adam'):
    optimizer = torch.optim.Adam(statnet_model.parameters(), lr=learning_rate)
elif(optimizer_type=='adagrad'):
    optimizer = torch.optim.Adagrad(statnet_model.parameters(), lr=learning_rate)
elif(optimizer_type=='adadelta'):
    optimizer = torch.optim.Adadelta(statnet_model.parameters(), lr=learning_rate)
else:
    print('No Optimizer Specified! Adam is default!')
    optimizer = torch.optim.Adam(statnet_model.parameters(), lr=learning_rate)

height of pyramid is, 5


## Run Training

In [7]:
imp.reload(sne)
loss_func = losses.GeneralizedLiftedStructureLoss()

training_loss = []
statnet_model.train() # Set model to training mode
optimizer.zero_grad()
statnet_model.zero_grad()
print('Starting Training:')
for i, epoch in enumerate(range(num_epochs)):
    for j, texture_batch in enumerate(dtd_dataloader):
        #grab texture batch and generate matching labels
        output = texture_batch[0].to(device)
        output = torch.flatten(output, start_dim=0, end_dim=1)
        texture_labels = torch.repeat_interleave(torch.arange(batch_size),num_crops)
        #apply random permutation
        perm = torch.randperm(batch_size * num_crops)
        output = output[perm]
        texture_labels = texture_labels[perm]

        #calculate stats
        stats_vector = statnet_model.encode(output)
        #loss definitions
        print(stats_vector.shape)
        loss = loss_func(stats_vector, texture_labels)
        optimizer.step()
        optimizer.zero_grad()
        loss.backward()

        print('*',end='')
        #training_loss.append(loss.item())      
        if(j%30==0):
            print(loss.item())
        if(j==100):
            break;
            
        training_loss.append(loss.item())

    print(f'Finished Epoch {i}. Loss at {loss}.')
    # torch.save(
    #         {'state_dict': statnet_model.state_dict(),
    #         'optimizer_state_dict': optimizer.state_dict()},
    #         f'{model_save_folder}/model_checkpoint_epoch_{i}.pth')
    
print('All Done!')

Starting Training:
height of pyramid is, 5
torch.Size([50, 1, 128, 128, 2])
torch.Size([50, 1, 128, 128, 2])
torch.Size([50, 1, 128, 128, 2])
torch.Size([50, 1, 128, 128, 2])
torch.Size([50, 1, 64, 64, 2])
torch.Size([50, 1, 64, 64, 2])
torch.Size([50, 1, 64, 64, 2])
torch.Size([50, 1, 64, 64, 2])
torch.Size([50, 1, 32, 32, 2])
torch.Size([50, 1, 32, 32, 2])
torch.Size([50, 1, 32, 32, 2])
torch.Size([50, 1, 32, 32, 2])
level pyr torch.Size([50, 26, 128, 128])
torch.Size([50, 26, 128, 128])
torch.Size([50, 200, 121, 121])


ValueError: embeddings must be a 2D tensor of shape (batch_size, embedding_size)

In [None]:
plt.plot(training_loss)

In [None]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = ",".join([str(gpu) for gpu in args.gpus])

print("Available/CUDA_VISIBLE_DEVICES", os.environ["CUDA_VISIBLE_DEVICES"])


In [None]:
n#steerable pyramid

im_batch = im_batch.to(device).float()

pyr_torch = SCFpyr_PyTorch(pyr_height, pyr_nbands, device=device)
coeff_torch = pyr_torch.build(im_batch)

In [None]:
output.shape

In [None]:
?list.insert