<a href="https://colab.research.google.com/github/avyaymc/Convolutional-Visual-Prompts/blob/main/CVPlatent.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

#overall-imports

In [None]:

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import models
from torchvision.transforms import Compose, Resize, ToTensor, Normalize


In [None]:
import torch
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, random_split

In [None]:
!pip install tqdm


Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [None]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

#Finetune resnet18 on cifar data

In [None]:
transform = transforms.Compose(
    [transforms.RandomHorizontalFlip(),
     transforms.RandomCrop(32, padding=4),
     transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)

Files already downloaded and verified


In [None]:
train_size = int(0.9 * len(trainset))
val_size = len(trainset) - train_size
trainset, valset = random_split(trainset, [train_size, val_size])

trainloader = DataLoader(trainset, batch_size=100, shuffle=True, num_workers=2)
valloader = DataLoader(valset, batch_size=100, shuffle=False, num_workers=2)


In [None]:
import torch
import torchvision.models as models
import torch.nn as nn

def resnet18_cifar():
    model = models.resnet18(pretrained=True)
    model.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
    model.maxpool = nn.Identity()
    model.avgpool = nn.AdaptiveAvgPool2d((1, 1))
    model.fc = nn.Linear(512, 10)
    return model

resnet18_cifar10 = resnet18_cifar()

In [None]:
from tqdm import tqdm


device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
resnet18_cifar10.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(resnet18_cifar10.parameters(), lr=0.001, momentum=0.9, weight_decay=5e-4)

num_epochs = 30
for epoch in range(num_epochs):
    running_loss = 0.0
    pbar = tqdm(enumerate(trainloader, 0), total=len(trainloader), ncols=100, leave=True)

    for i, data in pbar:
        inputs, labels = data
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()

        outputs = resnet18_cifar10(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        pbar.set_description(f"Epoch {epoch+1}, Loss: {running_loss / (i + 1):.4f}")

print("Finished fine-tuning")


Epoch 1, Loss: 1.1840: 100%|██████████████████████████████████████| 450/450 [00:47<00:00,  9.41it/s]
Epoch 2, Loss: 0.6106: 100%|██████████████████████████████████████| 450/450 [00:40<00:00, 10.99it/s]
Epoch 3, Loss: 0.4621: 100%|██████████████████████████████████████| 450/450 [00:41<00:00, 10.89it/s]
Epoch 4, Loss: 0.3724: 100%|██████████████████████████████████████| 450/450 [00:40<00:00, 11.00it/s]
Epoch 5, Loss: 0.3157: 100%|██████████████████████████████████████| 450/450 [00:40<00:00, 10.98it/s]
Epoch 6, Loss: 0.2734: 100%|██████████████████████████████████████| 450/450 [00:41<00:00, 10.90it/s]
Epoch 7, Loss: 0.2414: 100%|██████████████████████████████████████| 450/450 [00:41<00:00, 10.90it/s]
Epoch 8, Loss: 0.2154: 100%|██████████████████████████████████████| 450/450 [00:41<00:00, 10.85it/s]
Epoch 9, Loss: 0.1948: 100%|██████████████████████████████████████| 450/450 [00:40<00:00, 11.17it/s]
Epoch 10, Loss: 0.1755: 100%|█████████████████████████████████████| 450/450 [00:40<00:00, 1

Finished fine-tuning





In [None]:
from torch.utils.data import DataLoader
import numpy as np

def evaluate(model, dataloader, device):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for data in dataloader:
            images, labels = data
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    return 100 * correct / total

In [None]:
val_accuracy = evaluate(resnet18_cifar10, valloader, device)
print(f"Validation accuracy: {val_accuracy:.2f}%")


Validation accuracy: 93.22%


#cifar-10-c

In [None]:
import numpy as np
from torchvision.transforms import ToTensor
from torch.utils.data import Dataset

class CIFAR10C(Dataset):
    def __init__(self, corruption_npy, labels_npy, transform=None):
        self.images = np.load(corruption_npy)
        self.labels = np.load(labels_npy)
        self.transform = transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image = self.images[idx]
        label = self.labels[idx]

        if self.transform:
            image = self.transform(image)

        return image, label


In [None]:
import os
from google.colab import drive
drive.mount('/content/drive')
data_folder = "/content/drive/MyDrive/CIFAR-10-C/"

labels_npy = f"{data_folder}/labels.npy"

# Load only the 'fog.npy' corruption file
corruption_file = "fog.npy"
corruption_npy = f"{data_folder}/{corruption_file}"
corruption_name = corruption_file[:-4]  # Remove the .npy extension
datasets = {corruption_name: CIFAR10C(corruption_npy, labels_npy, transform=ToTensor())}

# Save the fine-tuned model
#model_path = '/content/drive/MyDrive/resnet18_cifar10_finetuned.pth'
#torch.save(resnet18_cifar10.state_dict(), model_path)



Mounted at /content/drive


#checking with corrupted images

In [None]:
print("Evaluating on CIFAR-10-C:")
results = {}
batch_size = 100

for corruption_name, dataset in datasets.items():
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=2)
    accuracy = evaluate(resnet18_cifar10, dataloader, device)
    results[corruption_name] = accuracy
    print(f"Accuracy for {corruption_name}: {accuracy:.2f}%")

print("Evaluation complete.")

Evaluating on CIFAR-10-C:
Accuracy for fog: 62.30%
Evaluation complete.


In [None]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
resnet18_cifar10.to(device)

results = {}
batch_size = 100

for corruption_name, dataset in datasets.items():
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=2)
    accuracy = evaluate(resnet18_cifar10, dataloader, device)
    results[corruption_name] = accuracy
    print(f"Accuracy for {corruption_name}: {accuracy:.2f}%")

print("Evaluation complete.")

Accuracy for fog: 62.30%
Evaluation complete.


#Convolutional Prompting at Latent Level


In [None]:
class ResNet18Latent(nn.Module):
    def __init__(self, original_model, device):
        super(ResNet18Latent, self).__init__()
        self.features = nn.Sequential(*list(original_model.children())[:-1]).to(device)

    def forward(self, x):
        x = self.features(x)
        return x.view(x.size(0), -1)



In [None]:
from google.colab import drive
import os

drive.mount('/content/drive')
model_save_path = "/content/drive/MyDrive/resnet18_cifar10_finetuned.pth"

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
def resnet18_cifar():
    model = models.resnet18(pretrained=True)
    model.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
    model.maxpool = nn.Identity()
    model.avgpool = nn.AdaptiveAvgPool2d((1, 1))
    model.fc = nn.Linear(512, 10)
    return model


resnet18_cifar10 = resnet18_cifar().to(device)

# Load the saved model parameters
resnet18_cifar10.load_state_dict(torch.load(model_save_path))
#resnet18_cifar10.load_state_dict(torch.load(model_save_path, map_location=torch.device('cpu')))

resnet18_latent = ResNet18Latent(resnet18_cifar10, device)


Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth
100%|██████████| 44.7M/44.7M [00:00<00:00, 190MB/s]


In [None]:
def apply_latent_prompt(latent_vectors, prompt_matrix):
    return torch.matmul(latent_vectors, prompt_matrix)


In [None]:
prompt_size = 512
v = torch.randn(prompt_size, prompt_size, requires_grad=True, device=device)


In [None]:
def evaluate_with_latent_prompt(model, latent_model, prompt_matrix, dataloader, device):
    model.eval()
    latent_model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for data in dataloader:
            images, labels = data
            images, labels = images.to(device), labels.to(device)
            latent_vectors = latent_model(images)
            prompted_latent_vectors = apply_latent_prompt(latent_vectors, prompt_matrix)
            outputs = model.fc(prompted_latent_vectors)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    return 100 * correct / total


In [None]:
results = {}
batch_size = 100

for corruption_name, dataset in datasets.items():
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=2)
    accuracy = evaluate_with_latent_prompt(resnet18_cifar10, resnet18_latent, v, dataloader, device)
    results[corruption_name] = accuracy
    print(f"Accuracy for {corruption_name}: {accuracy:.2f}%")

print("Evaluation complete.")


Accuracy for fog: 9.56%
Evaluation complete.


#simpleoptimize

In [None]:
class SimpleMLP(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super(SimpleMLP, self).__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        return x


In [None]:

input_dim = 512
hidden_dim = 128
output_dim = 1
self_supervised_model = SimpleMLP(input_dim, hidden_dim, output_dim)


In [None]:
class ResNetFeatures(nn.Module):
    def __init__(self, original_model):
        super(ResNetFeatures, self).__init__()
        self.features = nn.Sequential(*list(original_model.children())[:-1])

    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size(0), -1)
        return x

resnet18_cifar10_features = ResNetFeatures(resnet18_cifar10)

In [None]:
class SelfSupervisedModel(nn.Module):
    def __init__(self, backbone, feature_dim):
        super(SelfSupervisedModel, self).__init__()
        self.backbone = backbone
        self.mlp = nn.Sequential(
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Linear(256, feature_dim)
        )

    def forward(self, x):
        x = self.backbone(x)
        x = self.mlp(x)
        return x


In [None]:
def contrastive_loss(z, y, tau=0.1, epsilon=1e-8):
    z = F.normalize(z, dim=1)
    cos_sim = torch.matmul(z, z.T) / (torch.norm(z, dim=1).unsqueeze(1) * torch.norm(z, dim=1).unsqueeze(0) + epsilon)

    exp_cos_sim = torch.exp(cos_sim / tau)
    exp_cos_sim_sum = exp_cos_sim.sum(1) - torch.diag(exp_cos_sim)
    loss = torch.log(exp_cos_sim_sum) - cos_sim.diagonal() / tau
    loss = loss.mean()

    return loss

In [None]:

import torch.nn.functional as F

def apply_conv_prompt(images, conv_kernel):
    return F.conv2d(images, conv_kernel, stride=1, padding=1)

In [None]:
def weights_init(m):
    if isinstance(m, nn.Linear):
        nn.init.xavier_normal_(m.weight)
        nn.init.zeros_(m.bias)

self_supervised_model.apply(weights_init)


for name, param in self_supervised_model.named_parameters():
    print(name, "Initial Parameter:", param.norm().item())

fc1.weight Initial Parameter: 14.339766502380371
fc1.bias Initial Parameter: 0.0
fc2.weight Initial Parameter: 1.4443987607955933
fc2.bias Initial Parameter: 0.0


In [None]:

def train_self_supervised_model(dataloader, model, self_supervised_model, conv_kernel, device, num_epochs, learning_rate):
    model.to(device)
    self_supervised_model.to(device)

    optimizer = optim.Adam(self_supervised_model.parameters(), lr=learning_rate)
    model.eval()

    for epoch in range(num_epochs):
        running_loss = 0.0
        for i, (inputs, _) in enumerate(dataloader, 0):
            inputs = inputs.to(device)


            # if torch.isnan(inputs).any() or torch.isinf(inputs).any():
            #     #print("NaN or Inf values found in the input images.")
            #     continue


            prompted_inputs = apply_conv_prompt(inputs, conv_kernel)


            with torch.no_grad():
                features = model(prompted_inputs)




            optimizer.zero_grad()
            z = self_supervised_model(features)
            loss = contrastive_loss(z, inputs)
            loss.backward()
            #print("Loss value:", loss.item())
            torch.nn.utils.clip_grad_norm_(self_supervised_model.parameters(), 1.0)
            print("Loss value:", loss.item())
            # for name, param in self_supervised_model.named_parameters():
            #     print(name, "Gradient:", param.grad.norm().item(), "Parameter:", param.norm().item())

            optimizer.step()

            running_loss += loss.item()

        print(f"Epoch {epoch+1}, Loss: {running_loss / (i+1)}")
    print("Finished Training")

In [None]:
conv_kernel = torch.nn.Parameter(torch.randn(3, 3, 3, 3) / 9, requires_grad=True)
feature_dim = 128
resnet18_latent = SelfSupervisedModel(resnet18_cifar10, feature_dim)

corrupted_dataloader = DataLoader(datasets['fog'], batch_size=batch_size, shuffle=True, num_workers=2)

num_epochs = 30
learning_rate = 0.001


conv_kernel = conv_kernel.to(device)
train_self_supervised_model(corrupted_dataloader, resnet18_cifar10_features, self_supervised_model, conv_kernel, device, num_epochs, learning_rate)



[1;30;43mStreaming output truncated to the last 5000 lines.[0m
Loss value: 3.892620086669922
Loss value: 3.904670238494873
Loss value: 3.89682674407959
Loss value: 3.920891046524048
Loss value: 3.8920202255249023
Loss value: 3.8920202255249023
Loss value: 3.89682674407959
Loss value: 3.895022392272949
Loss value: 3.892620086669922
Loss value: 3.904670238494873
Loss value: 3.892620086669922
Loss value: 3.895022392272949
Loss value: 3.895022392272949
Loss value: 3.891819953918457
Loss value: 3.89682674407959
Loss value: 3.9315319061279297
Loss value: 3.89682674407959
Loss value: 3.904670238494873
Loss value: 3.891819953918457
Loss value: 3.8920202255249023
Loss value: 3.893620729446411
Loss value: 3.89682674407959
Loss value: 3.895022392272949
Loss value: 3.908102035522461
Loss value: 3.90164852142334
Loss value: 3.8926198482513428
Loss value: 3.893620491027832
Loss value: 3.891819953918457
Loss value: 3.90164852142334
Loss value: 3.904670238494873
Loss value: 3.892620086669922
Loss va

#latentoptimize(memory issues for running)


In [None]:
class ResNetFeatures(torch.nn.Module):
    def __init__(self, model, layer_index, conv_kernel=None):
        super(ResNetFeatures, self).__init__()
        self.features = torch.nn.Sequential(*list(model.children())[:layer_index])
        self.conv_kernel = conv_kernel

    def forward(self, x):
        if self.conv_kernel is not None:
            x = F.conv2d(x, self.conv_kernel, stride=1, padding=1)
        x = self.features(x)
        return x

In [None]:
layer_index = 1
resnet18_cifar10_features = ResNetFeatures(resnet18_cifar10, layer_index)

In [None]:
layer_input_channels = resnet18_cifar10_features.features[layer_index - 1].out_channels
conv_kernel = torch.nn.Parameter(torch.randn(layer_input_channels, layer_input_channels, 3, 3) / 9, requires_grad=True)

In [None]:
def apply_conv_prompt(images, conv_kernel):
    return F.conv2d(images, conv_kernel, stride=1, padding=1)

In [None]:
with torch.no_grad():
    sample_input = torch.randn(1, 3, 32, 32).to(device)
    sample_output = resnet18_cifar10_features(sample_input)
    feature_shape = sample_output.shape[1:]
    feature_dim = torch.prod(torch.tensor(feature_shape)).item()
    print("Feature shape:", feature_shape)
    print("Feature dim:", feature_dim)


Feature shape: torch.Size([64, 32, 32])
Feature dim: 65536


In [None]:
input_dim = feature_dim
self_supervised_model = SimpleMLP(input_dim, hidden_dim, output_dim)


In [None]:
def train_self_supervised_model(dataloader, model, self_supervised_model, conv_kernel, device, num_epochs, learning_rate, layer_index):
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(list(self_supervised_model.parameters()) + list(conv_kernel.parameters()), lr=learning_rate)

    for epoch in range(num_epochs):
        running_loss = 0.0
        for i, data in enumerate(dataloader, 0):
            inputs, _ = data
            inputs = inputs.to(device)

            optimizer.zero_grad()


            prompted_inputs = F.conv2d(inputs, conv_kernel, stride=1, padding=1)


            with torch.no_grad():
                features = model(prompted_inputs)


            outputs = self_supervised_model(features.view(features.size(0), -1))
            loss = criterion(outputs, torch.arange(0, outputs.size(0), dtype=torch.long, device=device))

            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        print(f"Epoch {epoch + 1}, Loss: {running_loss / (i + 1)}")

    print("Finished training self-supervised model")

In [None]:
class ConvKernel(torch.nn.Module):
    def __init__(self):
        super(ConvKernel, self).__init__()
        self.conv_kernel = torch.nn.Parameter(torch.randn(64, 3, 3, 3) / 9, requires_grad=True)

    def forward(self, x):
        return F.conv2d(x, self.conv_kernel, stride=1, padding=1)

conv_kernel = ConvKernel().to(device)
resnet18_cifar10_features = ResNetFeatures(resnet18_cifar10, layer_index, conv_kernel)
train_self_supervised_model(corrupted_dataloader, resnet18_cifar10_features, self_supervised_model, conv_kernel, device, num_epochs, learning_rate, layer_index)

#final

In [None]:
class ResNetFeatures(nn.Module):
    def __init__(self, original_model):
        super(ResNetFeatures, self).__init__()
        self.features = nn.Sequential(*list(original_model.children())[:-1])

    def forward(self, x, layer_index):
        for idx, layer in enumerate(self.features):
            x = layer(x)
            if idx == layer_index:
                break
        return x


In [None]:
class SelfSupervisedModel(nn.Module):
    def __init__(self, backbone, feature_dim):
        super(SelfSupervisedModel, self).__init__()
        self.backbone = backbone
        self.adaptive_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(feature_dim, feature_dim)

    def forward(self, x):
        x = self.backbone(x)
        x = self.adaptive_pool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x


In [None]:
layer_index = 1
resnet18_cifar10_features = ResNetFeatures(resnet18_cifar10, layer_index)

In [None]:
layer_input_channels = resnet18_cifar10_features.features[layer_index - 1].out_channels
conv_kernel = torch.nn.Parameter(torch.randn(layer_input_channels, layer_input_channels, 3, 3) / 9, requires_grad=True)


In [None]:
def apply_conv_prompt(images, conv_kernel):
    return F.conv2d(images, conv_kernel, stride=1, padding=1)

In [None]:
with torch.no_grad():
    sample_input = torch.randn(1, 3, 32, 32).to(device)
    sample_output = resnet18_cifar10_features(sample_input)
    feature_shape = sample_output.shape[1:]
    feature_dim = torch.prod(torch.tensor(feature_shape)).item()
    print("Feature shape:", feature_shape)
    print("Feature dim:", feature_dim)

Feature shape: torch.Size([64, 32, 32])
Feature dim: 65536


In [None]:
input_dim = feature_dim
self_supervised_model = SimpleMLP(input_dim, hidden_dim, output_dim).to(device)


In [None]:
def train_self_supervised_model(dataloader, model, self_supervised_model, conv_kernel, device, num_epochs, learning_rate, layer_index):
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(list(self_supervised_model.parameters()) + list(conv_kernel.parameters()), lr=learning_rate)

    for epoch in range(num_epochs):
        running_loss = 0.0
        for i, data in enumerate(dataloader, 0):
            inputs, _ = data
            inputs = inputs.to(device)

            optimizer.zero_grad()

            # Apply the conv_kernel (CVP)
            prompted_inputs = F.conv2d(inputs, conv_kernel.weight, stride=1, padding=1)

            # Get the features from the ResNet-18 model
            with torch.no_grad():
                features = model(prompted_inputs)

            # Pass the features through the self-supervised model
            outputs = self_supervised_model(features.view(features.size(0), -1))
            loss = criterion(outputs, torch.arange(0, outputs.size(0), dtype=torch.long, device=device))

            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {running_loss / (i + 1)}')

    print("Finished training self-supervised model")

In [None]:
class ConvKernel(nn.Module):
    def __init__(self, kernel_size=3, in_channels=3, out_channels=3):
        super(ConvKernel, self).__init__()
        self.weight = nn.Parameter(torch.randn(out_channels, in_channels, kernel_size, kernel_size))

    def forward(self, x):
        x = x.view(-1, 3, 32, 32)  # Reshape the input tensor
        return F.conv2d(x, self.weight, stride=1, padding=1)



conv_kernel = ConvKernel().to(device)
self_supervised_model = SelfSupervisedModel(resnet18_cifar10_features, feature_dim).to(device)
resnet18_cifar10_features = ResNetFeatures(resnet18_cifar10, layer_index, conv_kernel_module=conv_kernel).to(device)

train_self_supervised_model(corrupted_dataloader, resnet18_cifar10_features, self_supervised_model, conv_kernel, device, num_epochs, learning_rate, layer_index)


In [None]:
def train_self_supervised_model(dataloader, model, self_supervised_model, conv_kernel_module, device, num_epochs, learning_rate, layer_index, accumulation_steps=4):
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(list(self_supervised_model.parameters()) + list(conv_kernel_module.parameters()), lr=learning_rate)

    for epoch in range(num_epochs):
        running_loss = 0.0
        for i, data in enumerate(dataloader, 0):
            inputs, labels = data
            inputs = inputs.to(device)
            labels = labels.to(device)  # Move the labels to the device

            # Apply the conv_kernel_module (CVP)
            prompted_inputs = conv_kernel_module(inputs)

            # Get the features from the ResNet-18 model
            features = model(prompted_inputs, layer_index)
            features = features.view(features.size(0), -1)

            # Forward pass in the self_supervised_model
            outputs = self_supervised_model(features)
            loss = criterion(outputs, labels)  # Use the original labels instead of the predicted ones

            # Backward pass and optimization
            loss.backward()

            if (i+1) % accumulation_steps == 0:
                optimizer.step()
                optimizer.zero_grad()

            running_loss += loss.item()

        # Print the average loss for this epoch
        print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss / (i+1)}")



In [None]:
batch_size = 25
accumulation_steps = 4
corrupted_dataloader = torch.utils.data.DataLoader(datasets['fog'], batch_size=batch_size, shuffle=True, num_workers=2)

class ConvKernel(nn.Module):
    def __init__(self, kernel_size, in_channels, out_channels):
        super(ConvKernel, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride=1, padding=1)

    def forward(self, x):
        return self.conv(x)

conv_kernel_module = ConvKernel(kernel_size=3, in_channels=3, out_channels=3).to(device)

resnet18_cifar10_features = ResNetFeatures(resnet18_cifar10).to(device)
train_self_supervised_model(corrupted_dataloader, resnet18_cifar10_features, self_supervised_model, conv_kernel_module, device, num_epochs, learning_rate, layer_index, accumulation_steps=accumulation_steps)

RuntimeError: ignored