In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from accelerate import Accelerator
from torchvision.models import resnet50
from torchvision.datasets import CIFAR10
from torch.utils.data import DataLoader, Dataset
from torchvision.transforms import Compose, ToTensor, Normalize, RandomHorizontalFlip, v2, Resize
from tqdm import tqdm
from accelerate import notebook_launcher
from sklearn.model_selection import train_test_split
import os
import timm
import time
os.environ["CUDA_VISIBLE_DEVICES"] = '0'

class lcnet_ft(nn.Module):
    def __init__(self, num_classes=10, embed_dim=1280):
        super(lcnet_ft, self).__init__()

        self.embed_dim = embed_dim
        
        # Load and freeze lcnet
        self.encoder = timm.create_model('timm/lcnet_075.ra2_in1k', pretrained=True, num_classes=num_classes)
        for param in self.encoder.parameters():
            param.requires_grad = False

        # unfreeze the last layer
        for param in self.encoder.classifier.parameters():
            param.requires_grad = True
        
    def forward(self, x):
        logits = self.encoder(x)  # Assuming [batch_size, embed_dim]
        
        return logits


class mnasnet_ft(nn.Module):
    def __init__(self, num_classes=10, embed_dim=1280):
        super(mnasnet_ft, self).__init__()

        self.embed_dim = embed_dim
        
        # Load and freeze lcnet
        self.encoder = timm.create_model('timm/mnasnet_small.lamb_in1k', pretrained=True, num_classes=num_classes)
        for param in self.encoder.parameters():
            param.requires_grad = False

        # unfreeze the last layer
        for param in self.encoder.classifier.parameters():
            param.requires_grad = True
        
    def forward(self, x):
        logits = self.encoder(x)  # Assuming [batch_size, embed_dim]
        
        return logits


class repghostnet_ft(nn.Module):
    def __init__(self, num_classes=10, embed_dim=1280):
        super(repghostnet_ft, self).__init__()

        self.embed_dim = embed_dim
        
        # Load and freeze lcnet
        self.encoder = timm.create_model('timm/repghostnet_050.in1k', pretrained=True, num_classes=num_classes)
        for param in self.encoder.parameters():
            param.requires_grad = False

        # unfreeze the last layer
        for param in self.encoder.classifier.parameters():
            param.requires_grad = True
        
    def forward(self, x):
        logits = self.encoder(x)  # Assuming [batch_size, embed_dim]
        
        return logits

class ResNetViTEnsembleWithTransformer(nn.Module):
    def __init__(self, num_classes=10, embed_dim=1280):
        super(ResNetViTEnsembleWithTransformer, self).__init__()

        self.embed_dim = embed_dim
        
        # Load and freeze 
        self.lcnet = lcnet_ft(num_classes=num_classes, embed_dim=embed_dim)
        self.mnasnet = mnasnet_ft(num_classes=num_classes, embed_dim=embed_dim)
        self.repghostnet = repghostnet_ft(num_classes=num_classes, embed_dim=embed_dim)

        self.lcnet.encoder.classifier = nn.Identity()
        self.mnasnet.encoder.classifier = nn.Identity()
        self.repghostnet.encoder.classifier = nn.Identity()

        self.classifier = nn.Linear(embed_dim, num_classes)
    
        
    def forward(self, x):
        # Get the features from the backbone
        lcnet_features = self.lcnet.encoder(x)
        mnasnet_features = self.mnasnet.encoder(x)
        repghostnet_features = self.repghostnet.encoder(x)

        # Final classification
        feature_concat = torch.cat((lcnet_features.unsqueeze(1), mnasnet_features.unsqueeze(1), repghostnet_features.unsqueeze(1)), dim=1)

        # take mean of the features
        feature_concat = torch.mean(feature_concat, dim=1).squeeze(1)
        
        # Final classification

        logits = self.classifier(feature_concat)

        return logits


# Assuming CIFAR10 - adapt transforms for your dataset
transform = Compose([
    Resize((224,224)),
    RandomHorizontalFlip(),
    # v2.ColorJitter(brightness=(0.5,1.5),contrast=(1),saturation=(0.5,1.5),hue=(-0.1,0.1)),
    ToTensor(),
    Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.247, 0.243, 0.261]),
])

# Dataset and DataLoader setup
class FTDataset(Dataset):
    def __init__(self, images, labels, transform=None):
        self.images = images
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image = self.images[idx]
        label = self.labels[idx]
        
        if self.transform:
            image = self.transform(image)
        
        return image, label

def train_model(train_loader, test_loader):
    accelerator = Accelerator(mixed_precision="fp16")
    model = ResNetViTEnsembleWithTransformer()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)
    loss_fn = nn.CrossEntropyLoss()

    model, optimizer, train_loader, test_loader = accelerator.prepare(model, optimizer, train_loader, test_loader)
    
    # time start
    start = time.time()
    for epoch in range(10):  # Adjust the number of epochs if needed
        model.train()
        total_loss = 0
        for x, labels in tqdm(train_loader, desc=f"Epoch {epoch + 1}"):
            optimizer.zero_grad()
            outputs = model(x)
            loss = loss_fn(outputs, labels)
            accelerator.backward(loss)
            optimizer.step()
            total_loss += loss.item()
        avg_loss = total_loss / len(train_loader)
        print(f"Epoch {epoch + 1}, Training Loss: {avg_loss:.4f}")
    # time end
    end = time.time()
    print("Time taken: ", end-start)

    # Optional test phase at the end of training
    model.eval()
    test_loss = 0
    with torch.no_grad():
        start = time.time()
        for x, labels in test_loader:
            outputs = model(x)
            loss = loss_fn(outputs, labels)
            test_loss += loss.item()
        end = time.time()
        print("Test time taken: ", (end-start)/len(test_loader))
    avg_test_loss = test_loss / len(test_loader)
    print(f"Epoch {epoch + 1}, Validation Loss: {avg_test_loss:.4f}")

    # Save the final model
    model_path = "/home/chen/EECE570/new_models/ensemble_model/model_final_average.pt"
    unencapsulated_model = accelerator.unwrap_model(model)
    accelerator.save(unencapsulated_model.state_dict(), model_path)


def main():
    root_dir = './data'
    batch_size = 128
    transform = Compose([
        Resize((224,224)),
        RandomHorizontalFlip(),
        ToTensor(),
        Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.247, 0.243, 0.261]),
    ])
    
    # Load full dataset
    full_dataset = CIFAR10(root=root_dir, train=True, download=True, transform=None)
    images = [image for image, _ in full_dataset]
    labels = [label for _, label in full_dataset]

    # Stratified train-test split
    train_images, test_images, train_labels, test_labels = train_test_split(
        images, labels, test_size=0.2, stratify=labels, random_state=42
    )

    # Creating datasets
    train_dataset = FTDataset(train_images, train_labels, transform=transform)
    test_dataset = FTDataset(test_images, test_labels, transform=transform)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=8, pin_memory=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=8, pin_memory=True)

    # Call train_model or any other training function here
    train_model(train_loader, test_loader)

notebook_launcher(main, num_processes=1)



Launching training on one GPU.
Files already downloaded and verified


Epoch 1: 100%|██████████| 313/313 [00:17<00:00, 17.87it/s]


Epoch 1, Training Loss: 1.8240


Epoch 2: 100%|██████████| 313/313 [00:15<00:00, 19.73it/s]


Epoch 2, Training Loss: 1.2476


Epoch 3: 100%|██████████| 313/313 [00:15<00:00, 19.82it/s]


Epoch 3, Training Loss: 0.9899


Epoch 4: 100%|██████████| 313/313 [00:15<00:00, 19.77it/s]


Epoch 4, Training Loss: 0.8533


Epoch 5: 100%|██████████| 313/313 [00:15<00:00, 19.81it/s]


Epoch 5, Training Loss: 0.7681


Epoch 6: 100%|██████████| 313/313 [00:15<00:00, 19.66it/s]


Epoch 6, Training Loss: 0.7104


Epoch 7: 100%|██████████| 313/313 [00:15<00:00, 19.74it/s]


Epoch 7, Training Loss: 0.6693


Epoch 8: 100%|██████████| 313/313 [00:15<00:00, 19.65it/s]


Epoch 8, Training Loss: 0.6362


Epoch 9: 100%|██████████| 313/313 [00:15<00:00, 19.69it/s]


Epoch 9, Training Loss: 0.6091


Epoch 10: 100%|██████████| 313/313 [00:15<00:00, 19.67it/s]

Epoch 10, Training Loss: 0.5882
Time taken:  160.34771490097046





Test time taken:  0.03756343865696388
Epoch 10, Validation Loss: 0.5772


In [None]:
# testing on testset
model = ResNetViTEnsembleWithTransformer()
model.load_state_dict(torch.load("/home/chen/EECE570/new_models/ensemble_model/model_final_average.pt"))
model.cuda()

full_dataset = CIFAR10(root='./data', train=False, download=True, transform=None)
images = [image for image, _ in full_dataset]
labels = [label for _, label in full_dataset]


testset = FTDataset(images, labels, transform=transform)
testloader = DataLoader(testset, batch_size=64, shuffle=True, num_workers=8, pin_memory=True)


model.eval()
total = 0
correct = 0
false = 0
false_positive = 0
false_negative = 0
with torch.no_grad():
    for x, labels in tqdm(testloader):
        x, labels = x.cuda(), labels.cuda()
        outputs = model(x)
        predictions = torch.argmax(outputs, 1)
        total += labels.size(0)
        correct += (predictions == labels).sum().item()
        false_positive += ((predictions == 1) & (labels == 0)).sum().item()
        false_negative += ((predictions == 0) & (labels == 1)).sum().item()


# output accuracy, recell, precision, f1 score
accuracy = correct / total
recall = correct / (correct + false_negative)
precision = correct / (correct + false_positive)
f1_score = 2 * (precision * recall) / (precision + recall)

print("Accuracy: {:.4f}".format(accuracy))
print("Recall: {:.4f}".format(recall))
print("Precision: {:.4f}".format(precision))
print("F1 Score: {:.4f}".format(f1_score))


Files already downloaded and verified


100%|██████████| 157/157 [00:04<00:00, 36.59it/s]

Accuracy: 0.8185
Recall: 0.9978
Precision: 0.9985
F1 Score: 0.9982



