In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
from einops.layers.torch import Rearrange
from accelerate import Accelerator
from torchvision.models import resnet50
from torchvision.datasets import CIFAR10
from torch.utils.data import DataLoader, Dataset
from torchvision.transforms import Compose, ToTensor, Normalize, RandomHorizontalFlip, v2, Resize
from sklearn.model_selection import train_test_split
from tqdm import tqdm
from accelerate import notebook_launcher
import os
import timm
os.environ["CUDA_VISIBLE_DEVICES"] = '0'

class lcnet_ft(nn.Module):
    def __init__(self, num_classes=10, embed_dim=1280):
        super(lcnet_ft, self).__init__()

        self.embed_dim = embed_dim
        
        # Load and freeze lcnet
        self.encoder = timm.create_model('timm/lcnet_075.ra2_in1k', pretrained=True, num_classes=num_classes)
        # for param in self.encoder.parameters():
        #     param.requires_grad = False

        # # unfreeze the last layer
        # for param in self.encoder.classifier.parameters():
        #     param.requires_grad = True

        # self.linear = nn.Linear(embed_dim, embed_dim)
        
    def forward(self, x):
        logits = self.encoder(x)  # Assuming [batch_size, embed_dim]

        # linear projection
        # logits = self.linear(logits)
        
        return logits


class mnasnet_ft(nn.Module):
    def __init__(self, num_classes=10, embed_dim=1280):
        super(mnasnet_ft, self).__init__()

        self.embed_dim = embed_dim
        
        # Load and freeze lcnet
        self.encoder = timm.create_model('timm/mnasnet_small.lamb_in1k', pretrained=True, num_classes=num_classes)
        # for param in self.encoder.parameters():
        #     param.requires_grad = False

        # # unfreeze the last layer
        # for param in self.encoder.classifier.parameters():
        #     param.requires_grad = True

        # self.linear = nn.Linear(embed_dim, embed_dim)
        
    def forward(self, x):
        logits = self.encoder(x)  # Assuming [batch_size, embed_dim]

        # linear projection
        # logits = self.linear(logits)
        
        return logits


class repghostnet_ft(nn.Module):
    def __init__(self, num_classes=10, embed_dim=1280):
        super(repghostnet_ft, self).__init__()

        self.embed_dim = embed_dim
        
        # Load and freeze lcnet
        self.encoder = timm.create_model('timm/repghostnet_050.in1k', pretrained=True, num_classes=num_classes)
        # for param in self.encoder.parameters():
        #     param.requires_grad = False


        # # unfreeze the last layer
        # for param in self.encoder.classifier.parameters():
        #     param.requires_grad = True

        # self.linear = nn.Linear(embed_dim, embed_dim)
        
    def forward(self, x):
        logits = self.encoder(x)  # Assuming [batch_size, embed_dim]

        # linear projection
        # logits = self.linear(logits)
        
        return logits
    

class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim, dropout = 0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        return self.net(x)

class Attention(nn.Module):
    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
        inner_dim = dim_head *  heads
        project_out = not (heads == 1 and dim_head == dim)

        self.heads = heads
        self.scale = dim_head ** -0.5

        self.norm = nn.LayerNorm(dim)

        self.attend = nn.Softmax(dim = -1)
        self.dropout = nn.Dropout(dropout)

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()

    def forward(self, x):
        x = self.norm(x)

        qkv = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)

        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale

        attn = self.attend(dots)
        attn = self.dropout(attn)

        out = torch.matmul(attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

class Transformer(nn.Module):
    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
                FeedForward(dim, mlp_dim, dropout = dropout)
            ]))

    def forward(self, x):
        for attn, ff in self.layers:
            x = attn(x) + x
            x = ff(x) + x

        return self.norm(x)

class ResNetViTEnsembleWithTransformer(nn.Module):
    def __init__(self, num_classes=10, embed_dim=1280, heads=8, dim_feedforward_hidden=2048, num_layers=3, dropout=0.1):
        super(ResNetViTEnsembleWithTransformer, self).__init__()

        self.embed_dim = embed_dim
        
        # Load and freeze 
        self.lcnet = lcnet_ft(num_classes=num_classes, embed_dim=embed_dim)
        self.mnasnet = mnasnet_ft(num_classes=num_classes, embed_dim=embed_dim)
        self.repghostnet = repghostnet_ft(num_classes=num_classes, embed_dim=embed_dim)

        self.lcnet.encoder.classifier = nn.Identity()
        self.mnasnet.encoder.classifier = nn.Identity()
        self.repghostnet.encoder.classifier = nn.Identity()

        # Class token
        self.class_token = nn.Parameter(torch.randn(1, 1, embed_dim))
        nn.init.xavier_uniform_(self.class_token)
        
        # Transformer projection
        # encoder_layer = nn.TransformerEncoderLayer(d_model=embed_dim, nhead=heads, dim_feedforward=dim_feedforward_hidden, norm_first = False)    
        # self.transformer_projection = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

        self.transformer_projection = Transformer(embed_dim, num_layers, heads, dim_feedforward_hidden, dim_feedforward_hidden, dropout=dropout)

        self.softmax = nn.Softmax(dim=1)

        self.classifier = nn.Sequential(
            nn.Linear(embed_dim, num_classes),
            # nn.ReLU(),
            # # nn.Linear(4096, 2048),
            # # nn.ReLU(),
            # nn.Linear(512, 512),
            # nn.ReLU(),
            # nn.Linear(512, num_classes)
        )
    
        
    def forward(self, x):
        # Get the features from the backbone
        lcnet_features = self.lcnet.encoder(x)
        mnasnet_features = self.mnasnet.encoder(x)
        repghostnet_features = self.repghostnet.encoder(x)

        # Concatenate class token with features
        batch_size = x.size(0)
        class_tokens = self.class_token.expand(batch_size, -1, -1)
        combined_features = torch.cat((class_tokens, lcnet_features.unsqueeze(1), mnasnet_features.unsqueeze(1), repghostnet_features.unsqueeze(1)), dim=1)

        # Pass through transformer projection
        transformer_output = self.transformer_projection(combined_features)

        # Use class token for classification
        class_token_final = transformer_output[:, 0, :] 

        # class_token_final = transformer_output.mean(dim = 1)

        # Final classification
        logits = self.classifier(class_token_final)
        return logits


# Assuming CIFAR10 - adapt transforms for your dataset
transform = Compose([
    Resize((224,224)),
    RandomHorizontalFlip(),
    # v2.ColorJitter(brightness=(0.5,1.5),contrast=(1),saturation=(0.5,1.5),hue=(-0.1,0.1)),
    ToTensor(),
    Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.247, 0.243, 0.261]),
])

# Dataset and DataLoader setup
class FTDataset(Dataset):
    def __init__(self, images, labels, transform=None):
        self.images = images
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image = self.images[idx]
        label = self.labels[idx]
        
        if self.transform:
            image = self.transform(image)
        
        return image, label

def train_model(train_loader, test_loader):
    accelerator = Accelerator(mixed_precision="fp16")
    model = ResNetViTEnsembleWithTransformer()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)
    loss_fn = nn.CrossEntropyLoss()

    model, optimizer, train_loader, test_loader = accelerator.prepare(model, optimizer, train_loader, test_loader)

    for epoch in range(30):  # Adjust the number of epochs if needed
        model.train()
        total_loss = 0
        for x, labels in tqdm(train_loader, desc=f"Epoch {epoch + 1}"):
            optimizer.zero_grad()
            outputs = model(x)
            loss = loss_fn(outputs, labels)
            accelerator.backward(loss)
            optimizer.step()
            total_loss += loss.item()
        avg_loss = total_loss / len(train_loader)
        print(f"Epoch {epoch + 1}, Loss: {avg_loss:.4f}")

    # Optional test phase at the end of training
    model.eval()
    test_loss = 0
    with torch.no_grad():
        for x, labels in test_loader:
            outputs = model(x)
            loss = loss_fn(outputs, labels)
            test_loss += loss.item()
    avg_test_loss = test_loss / len(test_loader)
    print(f"Test Loss: {avg_test_loss:.4f}")

    # Save the final model
    model_path = "/home/chen/EECE570/new_models/ensemble_model/model_final_slctoken_unfrozen.pt"
    unencapsulated_model = accelerator.unwrap_model(model)
    accelerator.save(unencapsulated_model.state_dict(), model_path)


def main():
    root_dir = './data'
    batch_size = 128
    transform = Compose([
        Resize((224,224)),
        RandomHorizontalFlip(),
        ToTensor(),
        Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.247, 0.243, 0.261]),
    ])
    
    # Load full dataset
    full_dataset = CIFAR10(root=root_dir, train=True, download=True, transform=None)
    images = [image for image, _ in full_dataset]
    labels = [label for _, label in full_dataset]

    # Stratified train-test split
    train_images, test_images, train_labels, test_labels = train_test_split(
        images, labels, test_size=0.2, stratify=labels, random_state=42
    )

    # Creating datasets
    train_dataset = FTDataset(train_images, train_labels, transform=transform)
    test_dataset = FTDataset(test_images, test_labels, transform=transform)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=8, pin_memory=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=8, pin_memory=True)

    # Call train_model or any other training function here
    train_model(train_loader, test_loader)

notebook_launcher(main, num_processes=1)

Launching training on one GPU.
Files already downloaded and verified


Epoch 1: 100%|██████████| 313/313 [01:02<00:00,  5.03it/s]


Epoch 1, Loss: 0.5310


Epoch 2: 100%|██████████| 313/313 [01:03<00:00,  4.96it/s]


Epoch 2, Loss: 0.1992


Epoch 3: 100%|██████████| 313/313 [01:03<00:00,  4.95it/s]


Epoch 3, Loss: 0.1306


Epoch 4: 100%|██████████| 313/313 [01:03<00:00,  4.96it/s]


Epoch 4, Loss: 0.0889


Epoch 5: 100%|██████████| 313/313 [01:03<00:00,  4.94it/s]


Epoch 5, Loss: 0.0739


Epoch 6: 100%|██████████| 313/313 [01:03<00:00,  4.96it/s]


Epoch 6, Loss: 0.0547


Epoch 7: 100%|██████████| 313/313 [01:03<00:00,  4.95it/s]


Epoch 7, Loss: 0.0511


Epoch 8: 100%|██████████| 313/313 [01:03<00:00,  4.96it/s]


Epoch 8, Loss: 0.0465


Epoch 9: 100%|██████████| 313/313 [01:03<00:00,  4.95it/s]


Epoch 9, Loss: 0.0449


Epoch 10: 100%|██████████| 313/313 [01:03<00:00,  4.95it/s]


Epoch 10, Loss: 0.0388


Epoch 11: 100%|██████████| 313/313 [01:03<00:00,  4.94it/s]


Epoch 11, Loss: 0.0377


Epoch 12: 100%|██████████| 313/313 [01:03<00:00,  4.94it/s]


Epoch 12, Loss: 0.0443


Epoch 13: 100%|██████████| 313/313 [01:03<00:00,  4.96it/s]


Epoch 13, Loss: 0.0484


Epoch 14: 100%|██████████| 313/313 [01:03<00:00,  4.95it/s]


Epoch 14, Loss: 0.0498


Epoch 15: 100%|██████████| 313/313 [01:03<00:00,  4.96it/s]


Epoch 15, Loss: 0.0548


Epoch 16: 100%|██████████| 313/313 [01:03<00:00,  4.95it/s]


Epoch 16, Loss: 0.0425


Epoch 17: 100%|██████████| 313/313 [01:03<00:00,  4.96it/s]


Epoch 17, Loss: 0.0367


Epoch 18: 100%|██████████| 313/313 [01:03<00:00,  4.94it/s]


Epoch 18, Loss: 0.0457


Epoch 19: 100%|██████████| 313/313 [01:03<00:00,  4.96it/s]


Epoch 19, Loss: 0.0288


Epoch 20: 100%|██████████| 313/313 [01:03<00:00,  4.95it/s]


Epoch 20, Loss: 0.0421


Epoch 21: 100%|██████████| 313/313 [01:03<00:00,  4.95it/s]


Epoch 21, Loss: 0.0355


Epoch 22: 100%|██████████| 313/313 [01:03<00:00,  4.96it/s]


Epoch 22, Loss: 0.0314


Epoch 23: 100%|██████████| 313/313 [01:03<00:00,  4.95it/s]


Epoch 23, Loss: 0.0215


Epoch 24: 100%|██████████| 313/313 [01:03<00:00,  4.96it/s]


Epoch 24, Loss: 0.0232


Epoch 25: 100%|██████████| 313/313 [01:03<00:00,  4.95it/s]


Epoch 25, Loss: 0.0236


Epoch 26: 100%|██████████| 313/313 [01:03<00:00,  4.96it/s]


Epoch 26, Loss: 0.0238


Epoch 27: 100%|██████████| 313/313 [01:03<00:00,  4.95it/s]


Epoch 27, Loss: 0.0459


Epoch 28: 100%|██████████| 313/313 [01:03<00:00,  4.96it/s]


Epoch 28, Loss: 0.0471


Epoch 29: 100%|██████████| 313/313 [01:03<00:00,  4.95it/s]


Epoch 29, Loss: 0.0429


Epoch 30: 100%|██████████| 313/313 [01:03<00:00,  4.94it/s]

Epoch 30, Loss: 0.0275





Test Loss: nan


In [5]:
# testing on testset
model = ResNetViTEnsembleWithTransformer()
model.load_state_dict(torch.load("/home/chen/EECE570/new_models/ensemble_model/model_final_slctoken_unfrozen.pt"))
model.cuda()

full_dataset = CIFAR10(root='./data', train=False, download=True, transform=None)
images = [image for image, _ in full_dataset]
labels = [label for _, label in full_dataset]


testset = FTDataset(images, labels, transform=transform)
testloader = DataLoader(testset, batch_size=64, shuffle=True, num_workers=8, pin_memory=True)


model.eval()
correct = 0
total = 0
for x, labels in tqdm(testloader):
    x, labels = x.cuda(), labels.cuda()
    outputs = model(x)
    predicted = torch.argmax(outputs, 1)
    total += labels.size(0)
    correct += (predicted == labels).sum().item()
print(f"Accuracy: {correct / total:.4f}")


Files already downloaded and verified


100%|██████████| 157/157 [00:05<00:00, 27.46it/s]

Accuracy: 0.9284





: 