In [105]:
import torch
import torch.nn as nn

class VanillaSequencerBlock(nn.Module):
    def __init__(self, input_size, hidden_size, mlp_input_size, mlp_output_size):
        super(VanillaSequencerBlock, self).__init__()

        
        self.normal_layer = nn.LayerNorm(input_size)

        
        self.bilstm = nn.LSTM(input_size, hidden_size, bidirectional=True, batch_first=True)

       
        self.normalization_after_merge = nn.LayerNorm(input_size)

        
        self.channel_mlp = nn.Linear(mlp_input_size, mlp_output_size)

    def forward(self, x):
        if isinstance(x, list):
            x = torch.stack(x)
        batch_size,channel, height= x.size()
       
        outputs=[]
        for index in range(batch_size):
            
            
            y=x[index]
            y = self.normal_layer(y)
            
            
            output, _ = self.bilstm(y)
            
            
            #y = y + output #following the paper instructions
            
           
            y = self.normalization_after_merge(y)
            
            
            channel_output = self.channel_mlp(y)
            
            
            #y = y + channel_output  #following the paper instructions
            
            outputs.append(y)
       
        return outputs


In [3]:
import torch
import torch.nn as nn

class PatchEmbedding(nn.Module):
    def __init__(self, image_size, patch_size, in_channels, embed_dim):
        super(PatchEmbedding, self).__init__()
        self.image_size = image_size
        self.patch_size = patch_size
        self.in_channels = in_channels
        self.embed_dim = embed_dim

        
        self.num_patches = (image_size // patch_size) ** 2

        self.projection = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        
        

       
        x = self.projection(x).flatten(2).transpose(1, 2)  

        return x
    def output_dimension(self):
        return self.embed_dim * self.num_patches



In [4]:
class PWLinearLayer(nn.Module):
    def __init__(self, in_features, out_features):
        super(PWLinearLayer, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.linear = nn.Linear(in_features, out_features)

    def forward(self, input_list):
        output_list = []
        for input_tensor in input_list:
            output_tensor = self.linear(input_tensor)
            output_list.append(output_tensor)
        stacked_output = torch.stack(output_list, dim=0)
        return stacked_output


In [5]:
class PatchMerging(nn.Module):
    def __init__(self, in_channels, out_channels, scale_factor):
        super(PatchMerging, self).__init__()
        self.scale_factor = scale_factor
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, patch_list):
        
        x = torch.cat(patch_list, dim=1).permute(0,1)
        x=x.unsqueeze(0)
        x=x.permute(2,1,0)
       
        x = self.conv(x)
        
        
        
        x = nn.functional.interpolate(x, scale_factor=self.scale_factor, mode='nearest')
        
        return x


In [111]:

import torch.nn.init as init
class VanillaSequencerBlockModel(nn.Module):
    def __init__(self, num_classes, in_channels):
        super(VanillaSequencerBlockModel, self).__init__()

        self.num_classes = num_classes
        

       
        self.patch_embedding_1 = PatchEmbedding(16, 8, in_channels, 128)#  patch embedding with an 8x8 kernel size for each patch
        self.ln_1 = nn.LayerNorm(128)

        
        self.sequencer_block_1 =  nn.Sequential(
            VanillaSequencerBlock(16, 48, 16, 128),
            VanillaSequencerBlock(96, 96, 96, 192),
            VanillaSequencerBlock(192, 192, 192,384),
            VanillaSequencerBlock(384, 192, 384,384)
        )


        
        self.patch_merging=PatchMerging(49152,128,2)

      
        self.sequencer_block_2 =  nn.Sequential(
            VanillaSequencerBlock(384, 192, 3,384),
            VanillaSequencerBlock(384, 192, 384,384),
            VanillaSequencerBlock(384, 192, 384,384)
        )

       
        self.pw_linear_1 = PWLinearLayer( 2,384)
        
        
        self.sequencer_block_3 =  nn.Sequential(
            VanillaSequencerBlock(384, 192, 384,384),
            VanillaSequencerBlock(384, 192, 384,384),
            VanillaSequencerBlock(384, 192, 384,384),
            VanillaSequencerBlock(384, 192, 384,384),
            VanillaSequencerBlock(384, 192, 384,384),
            VanillaSequencerBlock(384, 192, 384,384),
            VanillaSequencerBlock(384, 192, 384,384),
            VanillaSequencerBlock(384, 192, 384,384),
            VanillaSequencerBlock(384, 192, 384,384),
            VanillaSequencerBlock(384, 192, 384,384),
            VanillaSequencerBlock(384, 192, 384,384),
            VanillaSequencerBlock(384, 192, 384,384),
            VanillaSequencerBlock(384, 192, 384,384),
            VanillaSequencerBlock(384, 192, 384,384)
            
           
        )

        self.pw_linear_2 = PWLinearLayer(384, 384)

       
        self.sequencer_block_4 =  nn.Sequential(
            VanillaSequencerBlock(384, 192, 384,384),
            VanillaSequencerBlock(384, 192, 384,384),
            VanillaSequencerBlock(384, 192, 384,384)
        )
        self.pw_linear_3 = PWLinearLayer(384, 384)


       
        self.ln_2 = nn.LayerNorm(384)

        
        self.global_avg_pool = nn.AdaptiveAvgPool2d((1, 384))

        
        self.fc = nn.Linear(384, num_classes)
        
        
        

    def forward(self, x):
       
       
        x = x.to(torch.float32).permute(0,3,1,2)#convert input tensor to float32 + compute permutation
        x = self.patch_embedding_1(x)
        x = self.ln_1(x)
        
        
        x=x.permute(0,2,1)
        x = self.sequencer_block_1(x)

       
        x = self.patch_merging(x)
        
        

        
        x = self.pw_linear_1(x)
        
       
        x = self.pw_linear_3(x)

       
        x = self.ln_2(x)
       
        x = self.global_avg_pool(x)
        
        x = self.fc(x)
     

        return x



In [7]:
from torchvision import datasets
import numpy as np
from torch.utils.data import Dataset

class CustomCIFAR2(Dataset):
    def __init__(self, root, train=True, transform=None, target_transform=None, download=False):
        super(CustomCIFAR2, self).__init__()
        self.cifar10 = datasets.CIFAR10(root, train=train, transform=transform, target_transform=target_transform, download=download)
        
        
        self.keep_classes = [0, 1, 2, 3, 4]  
        self.data, self.targets = self.filter_classes()

    def filter_classes(self):
        mask = np.isin(self.cifar10.targets, self.keep_classes)
        data = [self.cifar10.data[i] for i, include in enumerate(mask) if include]
        targets = [self.cifar10.targets[i] for i, include in enumerate(mask) if include]
        return data, targets

    def __getitem__(self, index):
        img, target = self.data[index], self.targets[index]
        return img, target

    def __len__(self):
        return len(self.data)


In [112]:
import torch.optim as optim

import torchvision.transforms as transforms
from torch.utils.data import DataLoader
image_size = 32
patch_size = 8
model = VanillaSequencerBlockModel(num_classes=2, in_channels=3)
criterion = nn.BCELoss()
optimizer = optim.Adam(model.parameters(), lr=0.01)


batch_size = 128

num_epochs = 10

transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),   # Randomly flip the image horizontally
    transforms.RandomRotation(15),      # Randomly rotate the image by up to 15 degrees
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2),
    transforms.ColorJitter(brightness=0.8, contrast=0.8, saturation=0.8, hue=0.8),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2),
    transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.4),# Adjust brightness, contrast, saturation, and hue
    transforms.RandomResizedCrop(16),
    transforms.RandomResizedCrop(4),# Randomly crop and resize the image to 224x224
    transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)),
    transforms.RandomAffine(degrees=4, translate=(0.4, 0.1)),# Randomly translate the image
    transforms.ToTensor(),              # Convert the image to a tensor
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # Normalize the image

])



#train_dataset = torchvision.datasets.CIFAR10(root='/Users/stellafazioli/Downloads/cifar-10-batches-py', train=True, transform=transform)
custom_dataset = CustomCIFAR2(root='/Users/stellafazioli/Downloads/cifar-10-batches-py', train=True, transform=transform, download=True)
data_loader = DataLoader(custom_dataset, batch_size=batch_size, shuffle=True)

Files already downloaded and verified


In [110]:
# Set your device (CPU or GPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# beginning of training phase
for epoch in range(num_epochs):
    total_correct = 0
    total_samples = 0
    running_loss = 0.0  #  loss for this epoch (but first 200 batch)

    for i, data in enumerate(data_loader, 0):
        if i==78:
            continue
        inputs, labels = data
        inputs, labels = inputs.to(device), labels.to(device)

        print(inputs.shape)
        optimizer.zero_grad()
        
        outputs = model(inputs)
        
        outputs=outputs.view(128,2)
        
       
        predicted_probabilities = torch.sigmoid(outputs)
       
        loss = criterion(predicted_probabilities[:,0], labels.float())
        
        loss.backward()
        optimizer.step()

        
        running_loss += loss.item()
        predicted_probabilities = outputs.argmax(dim=1)
        correct = (predicted_probabilities == labels).sum().item()
        total_correct += correct
        total_samples += labels.size(0)
        print(predicted_probabilities.shape)
        batch_accuracy = (correct / labels.size(0)) * 100.0
        print(f"Epoch [{epoch+1}/{num_epochs}] Batch [{i}/{len(data_loader)}] Loss: {loss.item():.4f} Accuracy: {batch_accuracy:.2f}%")
        if i % 200 == 199:  # Print every 200 mini-batches
            print(f"Epoch [{epoch+1}/{num_epochs}] Batch [{i}/{len(data_loader)}] Loss: {loss.item():.4f} Accuracy: {batch_accuracy:.2f}%")
            print(f"[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 200:.3f}")
            running_loss = 0.0

    epoch_accuracy = (total_correct / total_samples) * 100.0
    print(f"Epoch [{epoch+1}/{num_epochs}] Accuracy: {epoch_accuracy:.2f}%")

print("Finished Training")


torch.Size([128, 32, 32, 3])


RuntimeError: Given normalized_shape=[96], expected input with shape [*, 96], but got input of size[128, 16]

In [None]:
import pandas
from sklearn import model_selection
from sklearn.linear_model import LogisticRegression
import pickle

model_pkl_file = "vanillasequencer_model384.pkl"

with open(model_pkl_file, 'wb') as file:  
    pickle.dump(model, file)

NameError: name 'model' is not defined