In [10]:
import numpy
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt 
import random
from PIL import Image
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10 

numpy.random.seed(42)
torch.manual_seed(42)

device = 'cuda' if torch.cuda.is_available() else 'cpu'
device

'cuda'

In [11]:
"""
def download_data(test, transform, batch_size, shuffle, num_workers): 
    if test: 
        test_set = CIFAR10(root="./cifar10", train=False, download=True, transform=transform)
        test_loader = DataLoader(test_set, batch_size, shuffle=shuffle, num_workers=num_workers)
    else: 
        train_set = CIFAR10(root="./cifar10", train=True, download=True, transform=transform)
        train_loader = DataLoader(train_set, batch_size, shuffle=shuffle, num_workers=num_workers)
        
    return train_loader, test_loader
"""

'\ndef download_data(test, transform, batch_size, shuffle, num_workers): \n    if test: \n        test_set = CIFAR10(root="./cifar10", train=False, download=True, transform=transform)\n        test_loader = DataLoader(test_set, batch_size, shuffle=shuffle, num_workers=num_workers)\n    else: \n        train_set = CIFAR10(root="./cifar10", train=True, download=True, transform=transform)\n        train_loader = DataLoader(train_set, batch_size, shuffle=shuffle, num_workers=num_workers)\n        \n    return train_loader, test_loader\n'

In [12]:
"""
def transform_data(): 
    transform = transforms.Compose(
    [
        transforms.ToTensor(), 
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
    
    train_data = download_data(test=False, transform=transform, batch_size=4, shuffle=True, num_workers=2)
    test_data = download_data(test=True, transform=transform, batch_size=4, shuffle=False, num_workers=2)
    
    return train_data, test_data
"""

'\ndef transform_data(): \n    transform = transforms.Compose(\n    [\n        transforms.ToTensor(), \n        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))\n    ])\n    \n    train_data = download_data(test=False, transform=transform, batch_size=4, shuffle=True, num_workers=2)\n    test_data = download_data(test=True, transform=transform, batch_size=4, shuffle=False, num_workers=2)\n    \n    return train_data, test_data\n'

In [13]:
# transform_data()

In [14]:
"""
def show_image(classes, img): 
    img = img / 2 + 0.5
    numpy_img = img.numpy()
    plt.imshow(numpy.transpose(numpy_img, (1, 2, 0)))
    plt.show()
    
data_loop = iter(train_data)
images, labels = next(data_loop)
"""

'\ndef show_image(classes, img): \n    img = img / 2 + 0.5\n    numpy_img = img.numpy()\n    plt.imshow(numpy.transpose(numpy_img, (1, 2, 0)))\n    plt.show()\n    \ndata_loop = iter(train_data)\nimages, labels = next(data_loop)\n'

In [15]:
transform = transforms.Compose([
    transforms.ToTensor(), 
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

train_set = CIFAR10(root="./cifar10", train=True, download=True, transform=transform)
train_loader = DataLoader(train_set, batch_size=4, shuffle=True, num_workers=2)    

test_set = CIFAR10(root="./cifar10", train=False, download=True, transform=transform)
test_loader = DataLoader(test_set, batch_size=4, shuffle=False, num_workers=2)

train_loader, test_loader

Files already downloaded and verified
Files already downloaded and verified


(<torch.utils.data.dataloader.DataLoader at 0x7d94b3356e00>,
 <torch.utils.data.dataloader.DataLoader at 0x7d94b3357b80>)

In [16]:
"""
def show_image(image, classes): 
    img = image / 2 + 0.5 
    np_img = img.numpy()
    plt.imshow(numpy.transpose(np_img, (1, 2, 0)))
    plt.show()
    

data_iter = iter(train_loader)
images, labels = next(data_iter)
batch_size = 4

classes = {'plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck'}

show_image(torchvision.utils.make_grid(images), classes)
print(' '.join(f"{classes[labels[j]]:5s}" for j in range(batch_size)))
"""

'\ndef show_image(image, classes): \n    img = image / 2 + 0.5 \n    np_img = img.numpy()\n    plt.imshow(numpy.transpose(np_img, (1, 2, 0)))\n    plt.show()\n    \n\ndata_iter = iter(train_loader)\nimages, labels = next(data_iter)\nbatch_size = 4\n\nclasses = {\'plane\', \'car\', \'bird\', \'cat\', \'deer\', \'dog\', \'frog\', \'horse\', \'ship\', \'truck\'}\n\nshow_image(torchvision.utils.make_grid(images), classes)\nprint(\' \'.join(f"{classes[labels[j]]:5s}" for j in range(batch_size)))\n'

In [17]:
class MSA(nn.Module): 
    def __init__(self, d, n_heads=4): 
        super(MSA, self).__init__()
        self.d = d
        self.n_heads = n_heads 
        
        assert d % n_heads == 0
        
        d_head = int(d / n_heads)
        self.q_mapping = nn.ModuleList([nn.Linear(d_head, d_head) for _ in range(self.n_heads)])
        self.k_mapping = nn.ModuleList([nn.Linear(d_head, d_head) for _ in range(self.n_heads)])
        self.v_mapping = nn.ModuleList([nn.Linear(d_head, d_head) for _ in range(self.n_heads)])
        self.d_head = d_head
        self.softmax = nn.Softmax(dim=-1)
    
    def forward(self, sequence): 
        result = []
        for sequence in sequences: 
            seq_result = []
            for head in range(self.n_heads): 
                q_mapping = self.q_mapping[head]
                k_mapping = self.k_mapping[head]
                v_mapping = self.v_mapping[head]
                
                seq = sequence[:, head * self.d_head: (head+1) * self.d_head]
                q, k, v = q_mapping(seq), k_mapping(seq), v_mapping(seq)
                
                attention = self.softmax(q @ K.T / (self.d_head ** 0.5))
                seq_result.append(attention @ v)
            result.append(torch.hstack(seq_result))
        return torch.cat([torch.unsqueeze(r, dim=0) for r in result])

In [18]:
class MLP(nn.Module): 
    def __init__(self, in_features, hidden_features=None, out_features=None, activation_layer=nn.GELU, drop=0.): 
        super(MLP, self).__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.layer1 = nn.Linear(in_features, hidden_features)
        self.activation = activation_layer()
        self.layer2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)
    
    def forward(self, x): 
        x = self.layer1(x)
        x = self.activation(x)
        x = self.drop(x)
        x = self.layer2(x)
        x = self.drop(x)
        return x

In [19]:
class SE(nn.Module): 
    def __init__(self, channels, reduction): 
        super(SE, self).__init__()
        mid_channels = channels // reduction 
        self.squeeze = nn.AdaptiveMaxPool2dAd(output_size=1)
        self.excitation = nn.Sequential(
            nn.Conv2d(channels, channels // reduction, kernel_size=1)
            nn.SiLU(), 
            nn.Conv2d(channels // reduction, channels, kernel_size=1)
            nn.Sigmoid()
        )
        
    def forward(self, x): 
        s = self.squeeze(x)
        e = self.excitation(s)
        return x * e

In [None]:
def conv_block(in_channels, out_channels, kernel_size=3, stride=1, padding=1, groups=1, bias=False, bn=True, act=True): 
    return nn.Sequential(
        nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride, padding=padding, groups=groups, bias=bias), 
        nn.BatchNorm2d(out_channels) if bn else nn.Identity()
        nn.SiLU() if act else nn.Identity()
    )

In [20]:
class FusedMBConv(nn.Module): 
    def __init__(self, in_channels, out_channels, expansion, kernel_size=3, stride=1, bn=True, act=True, r=24, dropout=0.1): 
        super(FusedMBConv, self).__init__()
        self.skip_connection = (in_channels == out_channels) and (stride == 1) 
        padding = (kernel_size - 1) // 2
        expanded = expansion * in_channels
        
        self.expand_conv = conv_block(in_channels, expanded, kernel_size=3, stride=stride, padding=1)
        self.point_conv = conv_block(expanded, out_channels, kernel_size=1, padding=0, act=False)
        
        if expansion == 1: 
            self.point_pw = nn.Identity()
            self.expand_pw = conv_block(in_channels, out_channels, kernel_size=3, stride=stride, padding=1)
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, x): 
        res = x 
        x = self.expand_pw(x)
        x = self.reduce_pw(x)
        if self.skip_connection: 
            x = self.dropout(x)
            x = x + res
        return x

In [23]:
class FeatExtract(nn.Module): 
    def __init__(self, dim, keep_dim=False): 
        super(FeatExtract, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1, groups=dim, bias=False)
            nn.GELU(), 
            SE(dim, dim),
            nn.Conv2d(dim, dim, kernel_size=1, stride=1, padding=0, bias=False)
        )
        
        if not keep_dim: 
            self.pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.keep_dim = keep_dim 
        
    def forward(self, x): 
        x = x.contiguous()
        x = x + self.conv(x)
        if not self.keep_dim: 
            x = self.pool(x) 
        return x

SyntaxError: invalid syntax. Perhaps you forgot a comma? (3576103443.py, line 5)