In [1]:
import torch
import torchvision
from torchvision.transforms import v2
import matplotlib.pyplot as plt

from utils.diffeo_container import sparse_diffeo_container
from utils.get_model_activation import retrieve_layer_activation, get_flatten_children

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f'Using {device} for inference')

Using cpu for inference


In [2]:
## Clark's scrambled label

import sys
sys.path.append('/scratch/cm6627/diffeo_cnn/experiment/006_RandomLabels/fitting-random-labels')

import torch
import model_wideresnet 

class ModelWeights:
    
    path = '/scratch/cm6627/diffeo_cnn/experiment/006_RandomLabels/ModelWeights/'
    EpochsAmount = [0, 60, 120, 180, 240]
    CorruptAmount = [0.0, 0.5, 1.0]

    @staticmethod
    def load_Model(corrupt: float, epochs: int) -> 'torch.model':
        ### Checks
        if corrupt not in ModelWeights.CorruptAmount:
            raise ValueError(f'`corrupt` must be: {ModelWeights.CorruptAmount}')
        if epochs not in ModelWeights.EpochsAmount:
            raise ValueError(f'`epochs` must be {ModelWeights.EpochsAmount}')

        ### Code
        epochs = str(int(epochs))
        if corrupt == 0.0:
            corrupt = '0p0'
        elif corrupt == 0.5:
            corrupt = '0p5'
        elif corrupt == 1.0:
            corrupt = '1p0'

        if epochs == 0:  # This is just a randomly initalized model 
            corrupt = '0p0'

        file_name = f'/Corrupt-{corrupt}/ModelWeights_{epochs}Epochs.pth'

        # I trained on these parameters, which are default to the paper's code
        depth = 28
        classes = 10
        widen_factor = 1
        drop_rate = 0
        model = model_wideresnet.WideResNet(depth, classes,
                                            widen_factor,
                                            drop_rate=drop_rate)

        model_weights_path = ModelWeights.path + file_name
        model.load_state_dict(torch.load(model_weights_path, map_location=device))
        
        return model

EpochsAvaliable = [0, 60, 120, 180, 240]
CorruptOptions = [0.0, 0.5, 1.0]

In [17]:
normalize = v2.Normalize(mean=[x/255.0 for x in [125.3, 123.0, 113.9]],
                                     std=[x / 255.0 for x in [63.0, 62.1, 66.7]])
inference_trans = v2.Compose([v2.ToImage(), v2.ToDtype(torch.float32, scale=True), normalize])

cifar_data = torchvision.datasets.CIFAR10('/vast/xj2173/diffeo/CIFAR10',train = False, download = True, transform=inference_trans)
data_loader = torch.utils.data.DataLoader(cifar_data,
                                          batch_size=50,
                                          shuffle=False)

Files already downloaded and verified


In [18]:
model = ModelWeights.load_Model(corrupt = CorruptOptions[1], epochs= EpochsAvaliable[3])
model = model.to(device)

In [19]:
diffeos = sparse_diffeo_container(32, 32)
diffeo_strength_list = [0.001, 0.01, 0.025, 0.05, 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, 0.5]
for strength in diffeo_strength_list:
    diffeos.sparse_AB_append(3,3,3,strength,20)
diffeos.get_all_grid()

In [20]:
data, _ = next(iter(data_loader))

In [33]:
data.shape[1:] + (20)

TypeError: can only concatenate tuple (not "int") to tuple

In [7]:
for data, label in data_loader:
    pass

In [21]:
deformed = diffeos(data.unsqueeze(1).expand(-1,20,-1,-1,-1))

In [30]:
(10,20,) + (30,)

(10, 20, 30)

In [23]:
deformed.shape

torch.Size([50, 13, 20, 3, 32, 32])

In [27]:
retrieve_layer_activation(model, deformed.reshape(-1, 3, 32, 32), [2,3,4,7])

{'2': tensor([[[[ 2.2376e+00,  1.2454e+00,  3.2706e+00,  ...,  4.2580e-01,
             1.5862e+00,  2.8904e+00],
           [ 5.6868e+00, -1.9659e+00, -4.5033e-01,  ..., -2.2985e+00,
            -3.4908e-01,  2.6950e-01],
           [ 5.9799e+00, -1.5727e+00, -2.4263e+00,  ..., -4.0543e+00,
            -6.8707e-01, -1.7483e+00],
           ...,
           [ 9.2120e+00,  1.4266e+00, -7.6218e+00,  ..., -2.1610e+01,
            -3.5780e+00,  4.3506e+00],
           [ 1.0309e+01,  5.6874e+00, -3.3727e+00,  ..., -2.5150e+01,
            -7.6375e+00,  6.3276e+00],
           [ 1.0147e+01,  4.6017e+00, -5.7183e+00,  ..., -2.0137e+01,
            -6.5471e+00,  3.3222e+00]],
 
          [[ 2.6096e+00,  1.2955e+01,  7.5178e+00,  ...,  5.3853e+00,
             8.3430e+00,  9.0695e+00],
           [-8.7266e+00, -4.5370e+00,  5.9934e-01,  ..., -3.0743e+00,
             2.5258e-01,  2.4342e+00],
           [ 1.1820e+00, -6.4348e-01, -5.2883e+00,  ..., -1.4946e+00,
             1.2826e+00,  2.0876e+

In [28]:
_.keys

<function dict.keys>

In [29]:
224/32

7.0

In [30]:
20 * 13 * 2500

650000

In [7]:
from torchinfo import summary
summary(model, [650000,3,32,32])

: 