In [1]:
import lightning as L
import torch

from torch.utils.data import DataLoader
from easydict import EasyDict as edict
from datasets import get_dataloaders
# Run configuration, built in a single constructor call. A nested edict keeps
# attribute-style access working for ``args.encoder.arch``.
args = edict(
    dataset="3dshapes",
    sub_dataset="composition",
    make_random_batches=True,
    train_method="linear",
    pretrained_reps=False,
    pretrained_encoder=False,
    encoder=edict(arch="cnn"),
    train_bs=256,
    num_workers=0,
    accum_batches=1,
    num_steps=100000,
    lr=0.001,
)

There was a problem when trying to write in your cache folder (/storage/cache). You should set the environment variable TRANSFORMERS_CACHE to a writable directory.


In [41]:
import torch.nn as nn

from model_info import BigEncoder

class LightningLinear(L.LightningModule):
    """Train an encoder so its representations are an affine function of the latents.

    For a latent vector ``z`` the target representation is
    ``bias + sum_i z[i] * directions[i]``. The loss is the mean squared error
    between the encoder output and that target. ``directions`` has shape
    ``(latent_dim, hidden_dim)`` and ``bias`` has shape ``(1, hidden_dim)``.
    """

    def __init__(self, args, encoder, hidden_dim=128, latent_dim=5, **kwargs):
        """Store the encoder and create the learnable directions and bias.

        Args:
            args: configuration object; this class only reads ``args.lr``.
            encoder: module applied to the input batch in ``encode``.
                NOTE(review): its output width is assumed to equal
                ``hidden_dim`` -- confirm against the encoder definition.
            hidden_dim: width of every direction vector and of the bias.
            latent_dim: number of latent factors (one direction per factor).
            **kwargs: accepted for call-site compatibility and ignored.
        """
        super().__init__()
        self.args = args
        self.encoder = encoder
        self.directions = nn.Parameter(torch.randn(latent_dim, hidden_dim))
        self.bias = nn.Parameter(torch.randn(1, hidden_dim))
        self.latent_dim = latent_dim

    def _latent_offset(self, latents):
        """Return ``sum_i latents[:, i] * directions[i]``.

        ``latents`` is expected to be ``(batch, latent_dim)``; the product is
        ``(batch, latent_dim, hidden_dim)`` and is reduced over the latent axis,
        giving ``(batch, hidden_dim)``.
        """
        return (latents.unsqueeze(-1) * self.directions).sum(dim=1)

    def encode(self, x):
        """Run the wrapped encoder on ``x``."""
        return self.encoder(x)

    def modulate(self, x, delta=None):
        """Shift representations ``x`` along the learned directions.

        Args:
            x: representations of shape ``(batch, hidden_dim)``.
            delta: per-sample latent offsets of shape ``(batch, latent_dim)``.
                ``None`` means a zero offset.

        Returns:
            ``x + sum_i delta[:, i] * directions[i]``, same shape as ``x``.
        """
        if delta is None:
            delta = torch.zeros((x.shape[0], self.latent_dim), device=x.device)
        # The per-direction contributions must be summed over the latent axis
        # before being added; otherwise (batch, hidden_dim) cannot broadcast
        # against (batch, latent_dim, hidden_dim).
        return x + self._latent_offset(delta)

    def forward(self, x):
        """Return the encoder representation of ``x``."""
        return self.encode(x)

    def linear_model_loss(self, reps, latents):
        """Mean squared error between ``reps`` and the affine model of ``latents``."""
        reconstructed_reps = self.bias + self._latent_offset(latents)
        return ((reps - reconstructed_reps) ** 2).mean()

    def log_metrics_split(self, metrics, split):
        """Prefix metric names with ``split``, log them, and return the prefixed dict.

        The returned dict still holds the original tensors, so the caller can
        hand the loss back to Lightning for the backward pass.
        """
        metrics = {f'{split}_{k}': v for k, v in metrics.items()}
        self.log_dict({k: v.item() for k, v in metrics.items()}, on_step=True, on_epoch=True, prog_bar=True, add_dataloader_idx=False)
        return metrics  # so lightning can train

    def step(self, batch):
        """Compute the metrics for one batch.

        ``batch`` is unpacked as ``(imgs, _, latents)``; the middle element is
        not used. Images and latents are cast to float before use.
        """
        imgs, _, latents = batch
        reps = self.encode(imgs.float())
        loss = self.linear_model_loss(reps, latents.float())
        return {'loss': loss, 'linear_loss': loss}

    def _split_step(self, batch, split):
        """Shared body of the train/val/test steps: compute, log, return the loss."""
        metrics = self.log_metrics_split(self.step(batch), split)
        return metrics[f'{split}_loss']

    def training_step(self, batch, batch_idx):
        """Lightning hook: loss for one training batch (logged as ``train_*``)."""
        return self._split_step(batch, "train")

    def validation_step(self, batch, batch_idx, dataloader_idx=0):
        """Lightning hook: loss for one validation batch (logged as ``val_*``)."""
        return self._split_step(batch, "val")

    def test_step(self, batch, batch_idx, dataloader_idx=0):
        """Lightning hook: loss for one test batch (logged as ``test_*``)."""
        return self._split_step(batch, "test")

    def configure_optimizers(self):
        """Return an AdamW optimizer over every registered parameter.

        ``self.parameters()`` covers the encoder (and any other sub-modules) as
        well as ``directions`` and ``bias``. Those two are direct
        ``nn.Parameter`` attributes, so collecting sub-module parameters alone
        would leave them at their random initial values.
        """
        return torch.optim.AdamW(self.parameters(), lr=self.args.lr)

In [42]:
# Build the encoder from the run config, then wrap it in the Lightning module.
print("Creating model!", flush=True)
encoder = BigEncoder(args)
model = LightningLinear(
    args,
    encoder=encoder,
    hidden_dim=256,
    latent_dim=6,
)

Creating model!


In [4]:
# Build the dataloaders from the run config. The result is indexed by split
# name; the 'train' and 'val' entries are the ones used by the cells below.
print("Loading dataloaders...", flush=True)
dls = get_dataloaders(args)

Loading dataloaders...
Loading 3dshapes dataset...


In [55]:
# Step-based schedule on a single GPU. With check_val_every_n_epoch=None the
# integer val_check_interval counts training batches across epochs, so
# validation runs every 10000 batches until max_steps is reached.
trainer = L.Trainer(
    # hardware
    accelerator="gpu",
    devices=1,
    # training length / gradient accumulation
    max_steps=args.num_steps,
    accumulate_grad_batches=args.accum_batches,
    # validation cadence
    check_val_every_n_epoch=None,
    val_check_interval=10000,
    # reporting: progress bar only, no callbacks or loggers passed
    enable_progress_bar=True,
    callbacks=[],
    logger=[],
)
trainer.fit(
    model,
    train_dataloaders=dls["train"],
    val_dataloaders=[dls["val"]],
    ckpt_path=None,
)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name         | Type       | Params
--------------------------------------------
0 | encoder      | BigEncoder | 381 K 
  | other params | n/a        | 1.8 K 
--------------------------------------------
383 K     Trainable params
0         Non-trainable params
383 K     Total params
1.532     Total estimated model params size (MB)
SLURM auto-requeueing enabled. Setting signal handlers.


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_steps=100000` reached.


In [44]:
# Pull a single batch from the training dataloader for interactive inspection.
batch = next(iter(dls['train']))

In [46]:
# The batch is a 3-tuple; only the first and last elements are used here.
# Bind the middle one to a named throwaway rather than `_`: once a notebook
# assigns `_`, IPython stops updating it with the last cell output.
imgs, _unused, latents = batch

In [36]:
# Shape experiment: a (batch, hidden) tensor cannot be added to a
# (batch, latent, hidden) tensor. Broadcasting aligns trailing dimensions, so
# dim 1 compares 256 against 6 and fails. The error is the point of this cell,
# so it is caught and displayed instead of aborting a Run All.
a = torch.randn((256, 256))
b = torch.randn(256, 6, 256)
try:
    (a + b).shape
except RuntimeError as err:
    print(f"RuntimeError: {err}")

RuntimeError: The size of tensor a (256) must match the size of tensor b (6) at non-singleton dimension 1

In [37]:
# Reducing over the middle axis leaves a (256, 256) tensor that can be added to `a`.
torch.sum(b, dim=1).shape

torch.Size([256, 256])

In [49]:
# Run the encoder on the sampled images. The input is moved to whatever device
# the module currently lives on, so the cell works on CPU as well as GPU.
model(imgs.float().to(model.device))

tensor([[-11.7509,   1.6815,  -4.6021,  ...,  -4.8339,  16.6794,   5.3639],
        [-16.7556,  -4.9350, -12.2518,  ...,   3.1550,  26.3667,  -3.7114],
        [ -6.3644,  -8.6171,  -2.7505,  ...,   1.7719,   0.4567,  -0.4432],
        ...,
        [ -4.6241,  -9.5187,  -8.6548,  ...,   0.8589,  -5.1543,   7.2975],
        [ -9.8387,  -1.2096, -11.7213,  ...,   1.3912,  17.7150,   0.0394],
        [-12.0111,  -4.0808,   0.1939,  ...,  -2.1665,  14.3975,   0.3059]],
       device='cuda:0', grad_fn=<AddmmBackward0>)

In [None]:
from model_info import LinearModulator

# Standalone modulator instance used by the smoke test in the next cell.
m = LinearModulator(
    input_dim=256,
    hidden_dim=256,
    latent_dim=6,
)

In [None]:
import torch
# Smoke test: feed a random (10, 256) batch through the modulator.
data = torch.randn((10, 256))

m(data)