In [1]:
import os
from copy import deepcopy

import torch
import numpy as np

In [2]:
from src.utils.common import DATASET_NAME_MAP
# Dataset key: looked up in DATASET_NAME_MAP and embedded in every file name saved below.
# Alternative dataset (swap with the active line to use it):
# DATASET_NAME = 'mm_tinyimagenet'
DATASET_NAME = 'mm_cifar10'

def get_held_out_data(nb_samples, transform, is_train):
    """Sample a class-balanced held-out subset and save it as tensors under ``data/``.

    The datasets are built by ``DATASET_NAME_MAP[DATASET_NAME]()``; every item is
    expected to unpack as ``((x_left, x_right), y)``. Indices are drawn with the
    global NumPy RNG, so the subset differs between runs unless the caller seeds it.

    Files written (``prefix`` is ``'train'`` when ``is_train`` else ``'val'``):
        data/{prefix}_{DATASET_NAME}_held_out_proper_x_left.pt
        data/{prefix}_{DATASET_NAME}_held_out_proper_x_right.pt
        data/{prefix}_{DATASET_NAME}_held_out_y.pt
        data/{prefix}_{DATASET_NAME}_held_out_blurred_x_right.pt  (only when ``transform`` is given)

    Args:
        nb_samples: Total number of samples requested. ``nb_samples // num_classes``
            indices are drawn per class without replacement, so any remainder is dropped.
        transform: Transform assigned to ``transform2`` on a deep copy of the dataset
            to produce the blurred variant (presumably the right-view transform —
            confirm against the dataset class). ``None`` skips the blurred export.
        is_train: Sample from the training split when true, otherwise from the test split.

    Raises:
        ValueError: If ``nb_samples`` yields fewer than one sample per class.
    """
    train_dataset, test_dataset, _ = DATASET_NAME_MAP[DATASET_NAME]()
    dataset = train_dataset if is_train else test_dataset

    labels = np.array(dataset.dataset.targets)
    num_classes = len(np.unique(labels))
    nb_samples_per_class = nb_samples // num_classes
    if nb_samples_per_class < 1:
        raise ValueError(
            f'nb_samples={nb_samples} is too small for {num_classes} classes: '
            'at least one sample per class is required'
        )

    # Draw the same number of indices from every class; labels are matched against
    # 0..num_classes-1, so they are assumed to be consecutive integers starting at 0.
    idxs = np.concatenate([
        np.random.choice(np.where(labels == i)[0], size=nb_samples_per_class, replace=False)
        for i in range(num_classes)
    ])

    x_data, y_data = zip(*(dataset[i] for i in idxs))
    x_data_left, x_data_right = zip(*x_data)
    x_data_left = torch.stack(x_data_left)
    x_data_right = torch.stack(x_data_right)
    y_data = torch.tensor(y_data)

    x_data_right_blurred = None
    if transform is not None:
        # Read the same indices through a copy of the dataset whose second transform
        # is replaced, leaving the original dataset untouched.
        dataset_blurred = deepcopy(dataset)
        dataset_blurred.transform2 = transform
        x_data_blurred, _ = zip(*(dataset_blurred[i] for i in idxs))
        # Only the right view is exported; the left blurred view is not saved.
        _, x_data_right_blurred = zip(*x_data_blurred)
        x_data_right_blurred = torch.stack(x_data_right_blurred)

    os.makedirs('data', exist_ok=True)
    prefix = 'train' if is_train else 'val'

    torch.save(x_data_left, f'data/{prefix}_{DATASET_NAME}_held_out_proper_x_left.pt')
    torch.save(x_data_right, f'data/{prefix}_{DATASET_NAME}_held_out_proper_x_right.pt')
    torch.save(y_data, f'data/{prefix}_{DATASET_NAME}_held_out_y.pt')
    if x_data_right_blurred is not None:
        torch.save(x_data_right_blurred, f'data/{prefix}_{DATASET_NAME}_held_out_blurred_x_right.pt')

In [3]:
# TRAIN
# Export a 1000-sample held-out subset drawn from the training split.
# NOTE(review): no random seed is set in the visible cells, so the sampled subset changes between runs.
from src.data.transforms import TRANSFORMS_BLURRED_RIGHT_NAME_MAP
# NOTE(review): the meaning of the 0.0 argument is defined by the transform factory — confirm.
transform_blurred = TRANSFORMS_BLURRED_RIGHT_NAME_MAP[DATASET_NAME](0.0)

get_held_out_data(nb_samples=1000, transform=transform_blurred, is_train=True)

Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified


In [4]:
# VAL
# Export a 1000-sample held-out subset drawn from the test split (saved with the 'val' prefix).
# NOTE(review): this cell always uses the CIFAR-10 transform module, regardless of DATASET_NAME.
from src.data import transforms_cifar10
# NOTE(review): the arguments (32, 32, 1/4, 0.0) are interpreted by 'transform_eval_blurred' — confirm their meaning there.
transform_blurred = transforms_cifar10.TRANSFORMS_NAME_MAP['transform_eval_blurred'](32, 32, 1/4, 0.0)

get_held_out_data(nb_samples=1000, transform=transform_blurred, is_train=False)

Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified


In [None]:
# Same value as set near the top of the notebook; presumably repeated so the
# inspection cells below can run without re-executing the export cells — confirm.
DATASET_NAME = 'mm_cifar10'

In [None]:
# get_held_out_data writes its files with a split prefix ('train' or 'val'), so the
# prefix has to be part of the path when the tensors are loaded back.
HELD_OUT_SPLIT = 'train'  # switch to 'val' to inspect the validation subset
proper_x_left = torch.load(f'data/{HELD_OUT_SPLIT}_{DATASET_NAME}_held_out_proper_x_left.pt')
proper_x_right = torch.load(f'data/{HELD_OUT_SPLIT}_{DATASET_NAME}_held_out_proper_x_right.pt')
blurred_x_right = torch.load(f'data/{HELD_OUT_SPLIT}_{DATASET_NAME}_held_out_blurred_x_right.pt')

# Save indices for corruption experiment

In [4]:
import numpy as np

from src.utils.common import DATASET_NAME_MAP

In [7]:
DATASET_NAME = 'mm_cifar10'
train_dataset, _, _ = DATASET_NAME_MAP[DATASET_NAME]()
y_data = np.array(train_dataset.dataset.targets)
num_classes = len(np.unique(y_data))

# Requested subset size is a fraction of the training set; with a fraction of 0.0
# this is zero, so every per-class draw below is empty.
nb_samples = int(0.0 * y_data.shape[0])
nb_samples_per_class = nb_samples // num_classes

# One class-wise draw without replacement per label, joined into a single index array.
idxs = np.concatenate([
    np.random.choice(np.nonzero(y_data == label)[0], size=nb_samples_per_class, replace=False)
    for label in range(num_classes)
])

Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified


In [9]:
import os

# np.save does not create missing directories, so make sure data/ exists even when
# this section is run on its own.
os.makedirs('data', exist_ok=True)
np.save(f'data/{DATASET_NAME}_subset_000.npy', idxs)

# Something else: save class-balanced training-set indices

In [9]:
# Dataset key read by get_indices (both for the dataset lookup and the output file name).
DATASET_NAME = 'mm_cifar10'
def get_indices(nb_samples):
    """Draw class-balanced training indices and save them as a tensor.

    ``nb_samples // num_classes`` indices are sampled without replacement from each
    class of the training split of ``DATASET_NAME`` (global NumPy RNG), concatenated
    in class order and written to ``data/train_{DATASET_NAME}_indices.pt``.

    Args:
        nb_samples: Total number of indices requested; the remainder of the
            division by the number of classes is dropped.
    """
    train_split, _, _ = DATASET_NAME_MAP[DATASET_NAME]()
    labels = np.array(train_split.dataset.targets)
    n_classes = len(np.unique(labels))
    per_class = nb_samples // n_classes

    picked = np.concatenate([
        np.random.choice(np.nonzero(labels == cls)[0], size=per_class, replace=False)
        for cls in range(n_classes)
    ])
    indices = torch.tensor(picked)

    if not os.path.exists('data'):
        os.mkdir('data')
    torch.save(indices, f'data/train_{DATASET_NAME}_indices.pt')
    
# Save 10000 training indices in total, split evenly across the classes.
get_indices(nb_samples=10000)

Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified


In [None]:
# x_data_proper = [transform_train_proper(x) for x in x_data]
# x_data_proper = torch.stack(x_data_proper)
# y_data_proper = torch.from_numpy(y_data)

# torch.save(x_data_proper, f'data/{DATASET_NAME}_held_out_proper_x.pt')
# torch.save(y_data_proper, f'data/{DATASET_NAME}_held_out_proper_y.pt')

In [None]:
# x_data_blurred = [transform_blurred(x) for x in x_data]
# x_data_blurred = torch.stack(x_data_blurred)
# y_data_blurred = torch.from_numpy(y_data)

# torch.save(x_data_blurred, f'data/{DATASET_NAME}_held_out_blurred_x.pt')
# torch.save(y_data_blurred, f'data/{DATASET_NAME}_held_out_blurred_y.pt')

In [None]:
# NOTE(review): no visible cell writes a file ending in '_x_fellow.pt'; get_held_out_data
# saves '{prefix}_..._x_left.pt' / '..._x_right.pt'. This path looks stale — confirm before running.
description = 'proper'
path = f'data/{DATASET_NAME}_held_out_{description}_x_fellow.pt'
torch.load(path)

In [None]:
import torchvision
import numpy as np

import matplotlib.pyplot as plt
def show(img):
    """Draw a channels-first image tensor on a new 20x20-inch figure.

    Args:
        img: Tensor convertible with ``.numpy()`` whose three axes are ordered
            (channels, height, width); it is transposed to (height, width, channels)
            before being passed to ``plt.imshow``.
    """
    pixels = img.numpy()
    plt.figure(figsize=(20, 20))
    channels_last = np.transpose(pixels, axes=(1, 2, 0))
    plt.imshow(channels_last, interpolation='nearest')


In [None]:
# Preview the first 16 proper (non-blurred) left-view samples as one image grid.
grid = torchvision.utils.make_grid(proper_x_left[:16])
show(grid)

In [None]:
# Preview the first 16 proper (non-blurred) right-view samples as one image grid.
grid = torchvision.utils.make_grid(proper_x_right[:16])
show(grid)

In [None]:
# Preview the first 16 blurred right-view samples as one image grid.
grid = torchvision.utils.make_grid(blurred_x_right[:16])
show(grid)


In [None]:
import matplotlib.pyplot as plt

def plot_histogram(data, bins=10):
    """Plot a histogram of ``data`` and display it immediately.

    Args:
        data: Values forwarded to ``plt.hist``.
        bins: Bin specification forwarded to ``plt.hist`` (default 10).
    """
    _ = plt.hist(x=data, bins=bins)
    plt.show()

In [None]:
# Compare value statistics of one held-out sample between the proper and blurred right views.
idx = 10
plot_histogram(torch.flatten(proper_x_right[idx]))
plot_histogram(torch.flatten(blurred_x_right[idx]))
# Sum of (|proper| - |blurred|), then sum of the signed difference.
print((proper_x_right[idx].abs() - blurred_x_right[idx].abs()).sum(), (proper_x_right[idx] - blurred_x_right[idx]).sum())
# Per view: sum of absolute values, plain sum, fraction of strictly positive entries.
print(proper_x_right[idx].abs().sum(), proper_x_right[idx].sum(), (proper_x_right[idx] > 0).float().mean())
print(blurred_x_right[idx].abs().sum(), blurred_x_right[idx].sum(), (blurred_x_right[idx] > 0).float().mean())
# Fraction of entries whose magnitude is larger in the proper view than in the blurred one.
print((proper_x_right[idx].abs() > blurred_x_right[idx].abs()).float().mean())

In [10]:
from src.utils.utils_model import load_model_specific_params
from src.utils.prepare import prepare_model

def load_model(model_name, checkpoint_path):
    """Build ``model_name`` with 32x32 RGB, 10-class defaults and optionally restore weights.

    Args:
        model_name: Key passed to ``load_model_specific_params`` and ``prepare_model``.
        checkpoint_path: Path to a state dict to load into the model. ``None``
            returns the model exactly as built by ``prepare_model``.

    Returns:
        The model returned by ``prepare_model``, with the checkpoint applied when given.
    """
    # Defaults are listed first so model-specific parameters override them on key clashes.
    model_params = {
        'num_classes': 10,
        'input_channels': 3,
        'img_height': 32,
        'img_width': 32,
        'overlap': 0,
        **load_model_specific_params(model_name),
    }

    model = prepare_model(model_name, model_params=model_params)
    if checkpoint_path is not None:
        # Map to CPU so checkpoints saved on a GPU can be loaded on any machine.
        state_dict = torch.load(checkpoint_path, map_location=torch.device('cpu'))
        model.load_state_dict(state_dict)
    return model

# models = []
# epochs = [0,20,40,60,80,120,160]

model_name = 'mm_resnet'

# No checkpoint is passed, so the model is used as built by prepare_model.
model = load_model(model_name, checkpoint_path=None)

In [11]:
# Show the model's module tree (its repr).
model

ResNet(
  (conv11): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(2, 2), bias=False)
  (maxpool1): Identity()
  (conv21): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(2, 2), bias=False)
  (maxpool2): Identity()
  (left_branch): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(2, 2), bias=False)
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): Identity()
    (4): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu1): ReLU(inplace=True)
        (relu2): 

# Change activations

In [12]:
import torch.nn as nn

def change_activation(model, old_activation, new_activation):
    """Recursively replace activation modules inside ``model`` in place.

    Every direct or nested child that is an instance of ``old_activation`` is
    swapped for a fresh ``new_activation()`` (called without arguments, so any
    configuration of the replaced module is not carried over). Children of other
    types are descended into.

    Args:
        model: Module whose children are scanned; modified in place.
        old_activation: Class (or tuple of classes) of the modules to replace.
        new_activation: Zero-argument callable producing the replacement module.
    """
    for child_name, child in model.named_children():
        if not isinstance(child, old_activation):
            # Not a match: look for activations further down this subtree.
            change_activation(child, old_activation, new_activation)
            continue
        setattr(model, child_name, new_activation())

# Usage: swap every nn.ReLU in the model for a default-constructed nn.LeakyReLU (in place).
change_activation(model, nn.ReLU, nn.LeakyReLU)

In [13]:
# Show the module tree again after the activation swap.
model

ResNet(
  (conv11): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(2, 2), bias=False)
  (maxpool1): Identity()
  (conv21): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(2, 2), bias=False)
  (maxpool2): Identity()
  (left_branch): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(2, 2), bias=False)
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): LeakyReLU(negative_slope=0.01)
    (3): Identity()
    (4): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): LeakyReLU(negative_slope=0.01)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu1): LeakyReLU(ne

In [14]:
# Forward pass on two random tensors of shape (10, 3, 32, 32), one per model input.
model(torch.randn(10,3,32,32), torch.randn(10,3,32,32))

In [None]:
from src.modules.aux_modules import DeadReLU
# Wrap the model's left branch with the DeadReLU helper and switch it on.
# NOTE(review): DeadReLU is defined elsewhere; what enable() registers is not visible here — confirm.
# NOTE(review): an earlier cell replaces nn.ReLU with nn.LeakyReLU in `model`; if DeadReLU
# looks for nn.ReLU modules it may find none after that swap — confirm.
module = DeadReLU(model.left_branch, is_left_branch=True)
module.enable()

In [None]:
# Run the first four held-out (left, right) sample pairs through the model.
x_true1, x_true2 = proper_x_left[:4], proper_x_right[:4]
model(x_true1, x_true2)

In [None]:
# Inspect the attribute exposed by the DeadReLU helper (presumably filled in during forward passes — confirm).
module.nb_of_dead_relu

In [None]:
# List every (name, module) pair inside the left branch.
list(model.left_branch.named_modules())

In [None]:
# Shape of a four-sample slice of the left-view held-out tensor.
proper_x_left[:4].shape