<a href="https://colab.research.google.com/github/avyaymc/Convolutional-Visual-Prompts/blob/main/wideresnet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Imports

In [None]:
!pip install torch torchvision




In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import models
from torchvision.transforms import Compose, Resize, ToTensor, Normalize

In [None]:

import torch
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, random_split

# Use the first CUDA GPU when one is visible, otherwise run on the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

!pip install tqdm



# WideResNet definition

In [None]:
import torch
import torch.nn as nn

class BasicBlock(nn.Module):
    """Pre-activation residual unit: (BN -> ReLU -> 3x3 conv) twice plus a shortcut.

    Args:
        in_planes: number of input channels.
        out_planes: number of channels produced by both 3x3 convolutions.
        stride: stride of the first 3x3 convolution and of the 1x1 projection
            shortcut (when one is created).
        drop_rate: dropout probability applied between the two convolutions
            while the module is training; 0 disables dropout.

    NOTE(review): the shortcut choice depends only on the channel count, so
    in_planes == out_planes with stride != 1 would produce mismatched shapes
    at the residual add. WideResNet in this file never builds that combination.
    """

    def __init__(self, in_planes, out_planes, stride, drop_rate=0.0):
        super(BasicBlock, self).__init__()
        self.bn1 = nn.BatchNorm2d(in_planes)
        self.relu1 = nn.ReLU(inplace=True)
        self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_planes)
        self.relu2 = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_planes, out_planes, kernel_size=3, stride=1,
                               padding=1, bias=False)
        self.drop_rate = drop_rate
        self.equalInOut = (in_planes == out_planes)
        # A 1x1 projection is only needed when the channel count changes.
        self.convShortcut = (
            None
            if self.equalInOut
            else nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride,
                           padding=0, bias=False)
        )

    def forward(self, x):
        # bn1 returns a new tensor, so the in-place ReLU never touches `x`.
        activated = self.relu1(self.bn1(x))
        if not self.equalInOut:
            # When channels change, the projection shortcut consumes the
            # activated tensor rather than the raw input.
            x = activated
        out = self.conv1(activated)
        if self.drop_rate > 0:
            # nn.functional avoids relying on an `F` alias defined in another cell.
            out = nn.functional.dropout(out, p=self.drop_rate, training=self.training)
        out = self.conv2(self.relu2(self.bn2(out)))
        shortcut = x if self.equalInOut else self.convShortcut(x)
        return torch.add(shortcut, out)


In [None]:
class NetworkBlock(nn.Module):
    """One WideResNet stage: ``nb_layers`` residual blocks in an ``nn.Sequential``.

    Only the first block maps ``in_planes -> out_planes`` and uses ``stride``;
    every following block maps ``out_planes -> out_planes`` with stride 1.

    Args:
        nb_layers: number of blocks in the stage.
        in_planes: input channels of the first block.
        out_planes: output channels of every block.
        block: factory called as ``block(in_planes, out_planes, stride, drop_rate)``.
        stride: stride of the first block.
        drop_rate: forwarded to every block.
    """

    def __init__(self, nb_layers, in_planes, out_planes, block, stride, drop_rate=0.0):
        super(NetworkBlock, self).__init__()
        self.layer = self._make_layer(block, in_planes, out_planes, nb_layers, stride, drop_rate)

    def _make_layer(self, block, in_planes, out_planes, nb_layers, stride, drop_rate):
        # Conditional expressions (not `and/or`) so falsy values such as a
        # stride of 0 are passed through unchanged.
        layers = [
            block(in_planes if i == 0 else out_planes,
                  out_planes,
                  stride if i == 0 else 1,
                  drop_rate)
            for i in range(nb_layers)
        ]
        return nn.Sequential(*layers)

    def forward(self, x):
        return self.layer(x)


In [None]:
class WideResNet(nn.Module):
    """Wide residual network (WRN-depth-widen_factor) for 3-channel images.

    Layout: 3x3 stem conv (16 channels) -> three stages of ``(depth - 4) // 6``
    ``BasicBlock``s with 16k, 32k and 64k channels (k = ``widen_factor``; the
    second and third stages use stride 2) -> BatchNorm + ReLU -> global
    average pool -> linear classifier.

    Args:
        depth: total depth; must be of the form 6n + 4.
        widen_factor: channel multiplier k for the residual stages.
        num_classes: number of output logits.
        drop_rate: dropout probability inside each residual block.
    """

    def __init__(self, depth, widen_factor, num_classes, drop_rate=0.0):
        # Local import: `math` is otherwise only imported in a later cell.
        import math

        super(WideResNet, self).__init__()
        n_channels = [16, 16 * widen_factor, 32 * widen_factor, 64 * widen_factor]
        assert (depth - 4) % 6 == 0, 'Wide-resnet depth should be 6n+4'
        n = (depth - 4) // 6
        block = BasicBlock
        self.conv1 = nn.Conv2d(3, n_channels[0], kernel_size=3, stride=1, padding=1, bias=False)
        self.block1 = NetworkBlock(n, n_channels[0], n_channels[1], block, 1, drop_rate)
        self.block2 = NetworkBlock(n, n_channels[1], n_channels[2], block, 2, drop_rate)
        self.block3 = NetworkBlock(n, n_channels[2], n_channels[3], block, 2, drop_rate)
        self.bn1 = nn.BatchNorm2d(n_channels[3])
        self.relu = nn.ReLU(inplace=True)
        self.fc = nn.Linear(n_channels[3], num_classes)
        self.n_channels = n_channels[3]

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                # Normal init with std sqrt(2 / fan_out), where
                # fan_out = kernel_h * kernel_w * out_channels.
                fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / fan_out))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                m.bias.data.zero_()

    def forward(self, x):
        """Return logits of shape (batch, num_classes) for an NCHW image batch."""
        out = self.conv1(x)
        out = self.block1(out)
        out = self.block2(out)
        out = self.block3(out)
        out = self.relu(self.bn1(out))
        # Average over whatever spatial extent remains (8x8 for 32x32 inputs).
        # A fixed 8x8 kernel followed by view(-1, C) would fold leftover
        # spatial positions into the batch dimension for other input sizes.
        out = nn.functional.adaptive_avg_pool2d(out, 1)
        out = torch.flatten(out, 1)
        return self.fc(out)


# Data

In [None]:
import numpy as np
from torchvision.transforms import ToTensor
from torch.utils.data import Dataset

class CIFAR10C(Dataset):
    """Map-style dataset over a corruption image array and its label array.

    Both ``.npy`` files are read fully into memory with ``np.load``.
    ``dataset[i]`` returns ``(image, label)``. ``image`` is row ``i`` of the
    image array, passed through ``transform`` when one is given. ``label`` is
    element ``i`` of the label array.

    Raises:
        ValueError: if the two arrays have different lengths, which would
            otherwise misalign or silently drop labels.
    """

    def __init__(self, corruption_npy, labels_npy, transform=None):
        self.images = np.load(corruption_npy)
        self.labels = np.load(labels_npy)
        if len(self.images) != len(self.labels):
            raise ValueError(
                f"{corruption_npy} has {len(self.images)} images but "
                f"{labels_npy} has {len(self.labels)} labels"
            )
        self.transform = transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image = self.images[idx]
        label = self.labels[idx]

        if self.transform is not None:
            image = self.transform(image)

        return image, label

In [None]:

import os
from google.colab import drive
drive.mount('/content/drive')
data_folder = "/content/drive/MyDrive/CIFAR-10-C/"

# os.path.join avoids the "//" produced by f"{data_folder}/..." when the
# folder path already ends with a slash.
labels_npy = os.path.join(data_folder, "labels.npy")


corruption_file = "fog.npy"
corruption_npy = os.path.join(data_folder, corruption_file)
# Corruption name = file name without its extension (e.g. "fog").
corruption_name = os.path.splitext(corruption_file)[0]
datasets = {corruption_name: CIFAR10C(corruption_npy, labels_npy, transform=ToTensor())}


Mounted at /content/drive


# Models

In [None]:
from google.colab import drive
import math
drive.mount('/content/drive')

# WRN-28-10 with 10 outputs; the architecture must match the checkpoint.
pretrained_model = WideResNet(depth=28, widen_factor=10, num_classes=10)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

checkpoint_path = "/content/drive/MyDrive/cifar10_standard.pth"
# map_location lets a checkpoint holding GPU tensors load on a CPU-only runtime.
checkpoint = torch.load(checkpoint_path, map_location=device)

# Accept either {'state_dict': ...} or a bare state dict.
state_dict = checkpoint['state_dict'] if 'state_dict' in checkpoint else checkpoint

# Strip a leading "module." (typically present when the model was saved while
# wrapped in nn.DataParallel) without touching occurrences elsewhere in a key.
state_dict = {
    (k[len("module."):] if k.startswith("module.") else k): v
    for k, v in state_dict.items()
}

pretrained_model.load_state_dict(state_dict)
pretrained_model = pretrained_model.to(device)
# The backbone is used as a frozen feature extractor: eval() makes BatchNorm
# use its stored running statistics instead of updating them on every batch.
pretrained_model.eval()

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
class WideResNetFeatures(nn.Module):
    """Truncated WideResNet that stops after a chosen number of stages.

    ``conv1`` and ``block1``..``block3`` are shared with (not copied from)
    ``model``. The stem conv always runs; ``layer_index`` then selects how many
    residual stages follow: 0 -> none, 1 -> block1, 2 -> block1+block2,
    3 or more -> all three. No final BatchNorm/ReLU/pooling is applied.
    """

    def __init__(self, model, layer_index):
        super().__init__()
        self.layer_index = layer_index

        self.conv1 = model.conv1
        self.block1 = model.block1
        self.block2 = model.block2
        self.block3 = model.block3

    def forward(self, x):
        features = self.conv1(x)
        for stage_number, stage in enumerate((self.block1, self.block2, self.block3), start=1):
            # Stages run in order, so the first one past layer_index ends the walk.
            if self.layer_index < stage_number:
                break
            features = stage(features)
        return features


In [None]:
# Tap the backbone after block2: layer_index counts the residual stages
# (block1..block3) that run after the stem convolution.
layer_index = 2
pretrained_wrn_features = WideResNetFeatures(model=pretrained_model, layer_index=layer_index)


In [None]:
import torch.nn.functional as F
def contrastive_loss(z, y, tau=0.1, epsilon=1e-8):
    """Supervised contrastive loss over a batch of embeddings.

    For every anchor ``i`` that has at least one other sample with the same
    label, the per-anchor loss is::

        log(sum_{j != i} exp(s_ij / tau)) - log(sum_{j != i, y_j == y_i} exp(s_ij / tau))

    where ``s_ij`` is the cosine similarity of rows ``i`` and ``j`` of ``z``.
    The result is the mean over those anchors. Anchors without a positive
    partner are skipped; previously each one added a constant
    ``-log(epsilon)`` (about 18.4) to the loss.

    Args:
        z: 2-D tensor ``(batch, dim)`` of embeddings.
        y: 1-D tensor ``(batch,)`` of integer labels.
        tau: softmax temperature.
        epsilon: lower bound on the norm used when L2-normalizing ``z``.

    Returns:
        Scalar tensor. When no anchor has a positive partner it is 0, but still
        connected to ``z`` so ``backward()`` works.
    """
    z = F.normalize(z, dim=1, eps=epsilon)
    # Rows are unit-norm, so the Gram matrix already holds cosine similarities.
    logits = torch.matmul(z, z.T) / tau

    batch = z.size(0)
    self_mask = torch.eye(batch, dtype=torch.bool, device=z.device)
    positive_mask = (y.unsqueeze(1) == y.unsqueeze(0)) & ~self_mask

    has_positive = positive_mask.any(dim=1)
    if not bool(has_positive.any()):
        return z.sum() * 0.0

    # Keep only anchors with a positive. Every kept row then has at least one
    # finite entry in both logsumexps, so the -inf masking yields no NaN grads.
    logits = logits[has_positive]
    self_mask = self_mask[has_positive]
    positive_mask = positive_mask[has_positive]

    log_all = torch.logsumexp(logits.masked_fill(self_mask, float("-inf")), dim=1)
    log_pos = torch.logsumexp(logits.masked_fill(~positive_mask, float("-inf")), dim=1)
    return (log_all - log_pos).mean()


In [None]:
class ConvKernel(nn.Module):
    """Conv + BatchNorm + linear head mapping a feature map to a flat embedding.

    Args:
        kernel_size: side of the square convolution kernel.
        in_channels: channels of the incoming feature map.
        out_channels: channels produced by the convolution.
        projection_dim: length of the output embedding.
        input_hw: optional ``(H, W)`` of the incoming feature map. When omitted,
            ``H`` and ``W`` are read from the notebook-global ``feature_shape``
            (indexed as ``(C, H, W)``), which must exist at construction time.
    """

    def __init__(self, kernel_size, in_channels, out_channels, projection_dim, input_hw=None):
        super(ConvKernel, self).__init__()
        # "same"-style padding (equal to the previous padding=1 for a 3x3 kernel).
        padding = kernel_size // 2
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride=1, padding=padding)
        nn.init.kaiming_normal_(self.conv.weight, mode='fan_out', nonlinearity='relu')
        self.bn = nn.BatchNorm2d(out_channels)

        if input_hw is None:
            input_hw = (feature_shape[1], feature_shape[2])
        in_h, in_w = int(input_hw[0]), int(input_hw[1])
        # Conv output size for stride 1: size + 2 * padding - kernel_size + 1.
        out_h = in_h + 2 * padding - kernel_size + 1
        out_w = in_w + 2 * padding - kernel_size + 1
        self.linear_input_size = out_channels * out_h * out_w
        self.linear = nn.Linear(self.linear_input_size, projection_dim)
        nn.init.xavier_normal_(self.linear.weight)

    def forward(self, x):
        """Map an NCHW feature map to a ``(batch, projection_dim)`` embedding."""
        x = self.conv(x)
        x = self.bn(x)
        x = torch.flatten(x, 1)
        return self.linear(x)

# Training

In [None]:
from tqdm import tqdm

def train_conv_kernel(dataloader, model, conv_kernel_module, device, num_epochs, learning_rate, layer_index, contrastive_loss_fn, accumulation_steps=4):
    """Train ``conv_kernel_module`` with a contrastive loss on frozen backbone features.

    Args:
        dataloader: iterable of ``(inputs, labels)`` batches that supports ``len()``.
        model: feature extractor. It is switched to eval mode and run under
            ``torch.no_grad()``, so neither its weights nor its BatchNorm
            running statistics change during training.
        conv_kernel_module: trainable head; only its parameters are optimized.
        device: device each batch is moved to.
        num_epochs: number of passes over ``dataloader``.
        learning_rate: SGD learning rate (momentum 0.9, weight decay 1e-5).
        layer_index: accepted for call compatibility; not used here.
        contrastive_loss_fn: called as ``contrastive_loss_fn(embeddings, labels)``.
        accumulation_steps: accepted for call compatibility; not used here
            (the optimizer steps after every batch).

    Returns:
        List with the mean training loss of each epoch.
    """
    optimizer = torch.optim.SGD(conv_kernel_module.parameters(), lr=learning_rate, momentum=0.9, weight_decay=1e-5)

    # Without eval(), every forward pass would overwrite the pretrained
    # BatchNorm running statistics with statistics of the corrupted batches.
    model.eval()
    conv_kernel_module.train()

    epoch_losses = []
    for epoch in range(num_epochs):
        running_loss = 0.0
        num_batches = 0

        progress_bar = tqdm(dataloader, total=len(dataloader), desc=f"Epoch {epoch+1}/{num_epochs}")

        for inputs, labels in progress_bar:
            inputs = inputs.to(device)
            labels = labels.to(device)

            # No graph through the frozen backbone: saves memory and avoids
            # accumulating unused gradients in its parameters.
            with torch.no_grad():
                features = model(inputs)

            optimizer.zero_grad()
            prompted_features = conv_kernel_module(features)
            loss = contrastive_loss_fn(prompted_features, labels)

            loss.backward()
            torch.nn.utils.clip_grad_norm_(conv_kernel_module.parameters(), 1.0)
            optimizer.step()

            running_loss += loss.item()
            num_batches += 1
            progress_bar.set_postfix({"loss": running_loss / num_batches})

        # max(..., 1) keeps an empty dataloader from raising instead of reporting 0.
        epoch_loss = running_loss / max(num_batches, 1)
        epoch_losses.append(epoch_loss)
        print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {epoch_loss:.4f}')

    return epoch_losses



In [None]:
# Probe the truncated backbone once to learn the shape of the features it emits.
# The probe runs in eval mode so the random input cannot leak into BatchNorm
# running statistics; the previous train/eval mode is restored afterwards.
_was_training = pretrained_wrn_features.training
pretrained_wrn_features.eval()
with torch.no_grad():
    sample_input = torch.randn(1, 3, 32, 32).to(device)
    sample_output = pretrained_wrn_features(sample_input)
    feature_shape = sample_output.shape[1:]
    feature_dim = torch.prod(torch.tensor(feature_shape)).item()
    print("Feature shape:", feature_shape)
    print("Feature dim:", feature_dim)
pretrained_wrn_features.train(_was_training)

# Channels of the feature map produced for each layer_index. WideResNetFeatures
# always applies conv1, so index 0 yields conv1's output channels, not the 3
# RGB input channels.
layer_input_channel_map = {
    0: pretrained_wrn_features.conv1.out_channels,
    1: pretrained_wrn_features.block1.layer[-1].conv2.out_channels,
    2: pretrained_wrn_features.block2.layer[-1].conv2.out_channels,
    3: pretrained_wrn_features.block3.layer[-1].conv2.out_channels
}

# Take the channel count measured by the probe. It is correct for every
# layer_index, including values > 3, where dict.get would have returned None.
layer_input_channels = int(feature_shape[0])



Feature shape: torch.Size([320, 16, 16])
Feature dim: 81920


In [None]:
# Length of the embedding produced by the projection head.
projection_dim = 64
# The head keeps the channel count of the tapped feature map.
conv_kernel_module = ConvKernel(
    kernel_size=3,
    in_channels=layer_input_channels,
    out_channels=layer_input_channels,
    projection_dim=projection_dim,
).to(device)


In [None]:
batch_size = 25
corrupted_dataloader = torch.utils.data.DataLoader(datasets['fog'], batch_size=batch_size, shuffle=True, num_workers=2)
num_epochs = 10
learning_rate = 0.0001
# Reuse the dataset already built in `datasets`. Constructing another CIFAR10C
# would np.load the same .npy files again and keep a second in-memory copy.
full_dataset = datasets[corruption_name]

# Severity-1 subset: the first 10,000 samples. This assumes the corruption file
# stores severities in consecutive blocks of 10,000; confirm for other files.
severity1_dataset = torch.utils.data.Subset(full_dataset, range(10000))

# DataLoader over the severity-1 subset.
severity1_dataloader = torch.utils.data.DataLoader(severity1_dataset, batch_size=batch_size, shuffle=True, num_workers=2)

train_conv_kernel(severity1_dataloader, pretrained_wrn_features, conv_kernel_module, device, num_epochs, learning_rate, layer_index, contrastive_loss)


Epoch 1/10: 100%|██████████| 400/400 [00:37<00:00, 10.77it/s, loss=3.26]

Epoch [1/10], Loss: 3.2599



Epoch 2/10: 100%|██████████| 400/400 [00:37<00:00, 10.64it/s, loss=3.02]

Epoch [2/10], Loss: 3.0247



Epoch 3/10: 100%|██████████| 400/400 [00:37<00:00, 10.68it/s, loss=2.89]

Epoch [3/10], Loss: 2.8944



Epoch 4/10: 100%|██████████| 400/400 [00:37<00:00, 10.68it/s, loss=2.81]

Epoch [4/10], Loss: 2.8057



Epoch 5/10: 100%|██████████| 400/400 [00:37<00:00, 10.63it/s, loss=2.73]

Epoch [5/10], Loss: 2.7252



Epoch 6/10: 100%|██████████| 400/400 [00:37<00:00, 10.66it/s, loss=2.72]

Epoch [6/10], Loss: 2.7199



Epoch 7/10: 100%|██████████| 400/400 [00:37<00:00, 10.63it/s, loss=2.65]

Epoch [7/10], Loss: 2.6519



Epoch 8/10: 100%|██████████| 400/400 [00:37<00:00, 10.63it/s, loss=2.57]

Epoch [8/10], Loss: 2.5679



Epoch 9/10: 100%|██████████| 400/400 [00:37<00:00, 10.59it/s, loss=2.53]

Epoch [9/10], Loss: 2.5321



Epoch 10/10: 100%|██████████| 400/400 [00:37<00:00, 10.64it/s, loss=2.49]

Epoch [10/10], Loss: 2.4867



