In [1]:
import sys, os
# path to the models I made
sys.path.insert(0, '/projappl/project_2005600/wallpaper/img_clustering')
os.environ['TORCH_HOME'] = '/scratch/project_2005600'

batch_size = int(sys.argv[1])
cluster_num = int(sys.argv[2])
learning_rate = float(sys.argv[3])

print(f'batch size: {sys.argv[1]}')
print(f'batch size: {sys.argv[2]}')
print(f'batch size: {sys.argv[3]}')

In [None]:
import torch
import torch.nn as nn
from torch.nn import functional as F

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# import models I made
from vae import VAE
from delius import DELIUS
from cluster import Cluster
from wallpaper_workshop_dataset import WallpaperWorkshopDataset

In [2]:
from torchvision import transforms
csv_file = '/scratch/project_2005600/wallpaper_jepg_rgb.csv'
root_dir = '/scratch/project_2005600/wallpaper_workshop'
img_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((224, 224), antialias = True),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

ww_dataset = WallpaperWorkshopDataset(csv_file, root_dir, img_transform)

In [12]:
from torch.utils.data import DataLoader
dataloader = DataLoader(ww_dataset, batch_size = 32, shuffle = True)

In [6]:
# feature extraction model
densenet_model = torch.hub.load('pytorch/vision:v0.10.0', 'densenet121', pretrained=True)

vae_model = VAE(1024)
feature_model = DELIUS(densenet_model.features, vae_model)

# optional: freeze the feature extraction (from a pre-trained model) layers
for params in feature_model.feature_model:
    params.requires_grad = False

feature_model.to(device)

# clustering model
centroids = torch.randn([5, 10], requires_grad = True)
centroids.to(torch.float32)
cluster_model = Cluster(centroids)
# cluster_model = ClusterLayer(5, 10)
cluster_model.to(device)

Using cache found in /scratch/project_2005600/hub/pytorch_vision_v0.10.0


Cluster()

In [7]:
# training utilities/ parameters
epochs = 10
learning_rate = .0001
optimizer = torch.optim.Adam(
    list(feature_model.parameters()) + list(cluster_model.parameters()),
    # cluster_model.parameters(),
    lr = learning_rate)

In [10]:
import gc
gc.collect()

992

In [11]:
# training
n_steps = len(dataloader)
for epoch in range(epochs):
    for i, x in enumerate(dataloader):
        # feature extraction
        # feed forward
        image = x['image'].to(device)
        recons_h, h, z, mu, sigma= feature_model(image)
        
        recons_h = recons_h.to(device)
        h = h.to(device)
        z = z.to(device)
        
        feature_loss = F.mse_loss(recons_h, h)

        # cluster
        # feed forward
        p, q = cluster_model(z)

        p = p.to(device)
        q = q.to(device)
        
        cluster_loss = F.kl_div(torch.log(p), torch.log(q), log_target = True)
        # cluster_loss.backward()
        
        # compute loss and gradients
        total_loss =  feature_loss + cluster_loss
        # total_loss =  cluster_loss
        total_loss.backward()

        # update parameters
        # centroid_old = cluster_model.centroids
        optimizer.step()

        # print status
        print(f'epoch: [{epoch+1}/{epochs}]')
        print(f'[{i+1}/{n_steps}]')
    
        print(f'total loss: {total_loss.item()}')
        print(f'feature loss: {feature_loss.item()}')
        print(f'cluster loss: {cluster_loss.item()}')
        # print (cluster_model.centroids.grad)
        # print (cluster_model.centroids)
            
        optimizer.zero_grad()



epoch: [1/10]
[1/5622]
total loss: 0.5757477283477783
feature loss: 0.566113293170929
cluster loss: 0.009634431451559067
epoch: [1/10]
[2/5622]
total loss: 0.5237880945205688
feature loss: 0.5153955817222595
cluster loss: 0.008392523042857647
epoch: [1/10]
[3/5622]
total loss: 0.4844299852848053
feature loss: 0.4785144329071045
cluster loss: 0.005915550049394369
epoch: [1/10]
[4/5622]
total loss: 0.4579017460346222
feature loss: 0.45247119665145874
cluster loss: 0.005430558230727911
epoch: [1/10]
[5/5622]
total loss: 0.4034084677696228
feature loss: 0.39669930934906006
cluster loss: 0.006709159817546606
epoch: [1/10]
[6/5622]
total loss: 0.37116295099258423
feature loss: 0.36412912607192993
cluster loss: 0.007033830042928457
epoch: [1/10]
[7/5622]
total loss: 0.33919647336006165
feature loss: 0.3321748971939087
cluster loss: 0.007021588273346424
epoch: [1/10]
[8/5622]
total loss: 0.2981194853782654
feature loss: 0.29089945554733276
cluster loss: 0.007220017723739147
epoch: [1/10]
[9/56

KeyboardInterrupt: 