In [2]:
# Set up
import sys
import numpy as np
import pandas as pd
from tqdm.autonotebook import tqdm

import torch
import torch.nn as nn
from torch.nn import functional as F
from transformers import ViTFeatureExtractor, ViTForImageClassification

import warnings
warnings.filterwarnings('ignore')

sys.path.append('..')

dataset_path = '../Datasets/CIFAR10'
img_path = dataset_path + '/images'
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [3]:
# Set data owner id of interest
data_owner_id = 'A'

In [4]:
def seed_everything(seed=20):
    """set seed for all"""
    import os
    import torch
    import random
    import numpy as np
    os.environ['PYTHONHASHSEED'] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
seed_everything()

In [5]:
def get_vit_model(device='cpu'):
    # download vision transformer model 
    vit_feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224-in21k')
    vit_model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224', output_hidden_states=True).to(device)

    # freeze the pre-trained model
    vit_model.eval()
    for param in vit_model.parameters():
        param.requires_grad = False
    return vit_feature_extractor, vit_model

In [21]:
# Training hyperparameters
num_epochs = 10
learning_rate = 0.001

In [6]:
# Load data owner dataset
data_owner_dataset = pd.read_excel(dataset_path + '/CIFAR10dataOwnerInfo.xlsx', sheet_name=data_owner_id)
data_owner_dataset.image = [f'{img_path}/{image}' for image in data_owner_dataset.image]

# Create data owner's model 
num_classes = data_owner_dataset.label_name.nunique()
model = nn.Sequential(
    nn.LazyBatchNorm1d(),
    nn.LazyLinear(128),
    nn.GELU(),
    nn.LazyLinear(num_classes)
).to(device)
vit_feature_extractor, vit_model = get_vit_model(device)

# Initialize loss function (criterion) and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

In [7]:
# Label encoding and get images and labels as lists
def label_enc(data_owner_dataset):
    from sklearn.preprocessing import LabelEncoder
    le = LabelEncoder().fit(data_owner_dataset.label_name)
    label2id = {k:v for k, v in zip(le.classes_, le.transform(le.classes_))}
    id2label = {v:k for k, v in label2id.items()}
    labels = le.transform(data_owner_dataset.label_name)
    images = data_owner_dataset.image.tolist()
    return images, labels, label2id, id2label

images, labels, label2id, id2label = label_enc(data_owner_dataset)

In [8]:
from core.ai.dataset import get_loader
import albumentations as A
train_transform = A.Compose([
    A.Resize(256, 256),
    A.RandomCrop(224, 224),
    A.Flip(p=0.5),
    A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2, p=0.5)
])
data_loader = get_loader(images, labels, vit_feature_extractor, train_transform,
                         pre_trained_model=vit_model, device=device)


In [19]:
# Test sample input for the model
sample_batch = next(iter(data_loader))
sample_images, sample_labels = sample_batch
logits = model(sample_images.to(device))
print(model)
print('Batch image shape:', list(sample_images.shape))
print('Batch label shape:', list(sample_labels.shape))
print('Model output shape:', list(logits.shape))

Sequential(
  (0): BatchNorm1d(768, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (1): Linear(in_features=768, out_features=128, bias=True)
  (2): GELU()
  (3): Linear(in_features=128, out_features=6, bias=True)
)
Batch image shape: [32, 768]
Batch label shape: [32]
Model output shape: [32, 6]


In [22]:
model.train()
for epoch in range(num_epochs):
    epoch_loss_list = []
    epoch_acc_sum = [0, 0]
    for batch_images, batch_labels in tqdm(data_loader):
        logits = model(batch_images.to(device))
        batch_labels = batch_labels.long().to(device)
        optimizer.zero_grad()
        loss = criterion(logits, batch_labels)
        loss.backward()
        optimizer.step()
        
        epoch_loss_list.append(loss.item())
        epoch_acc_sum[0] += (logits.argmax(1) == batch_labels).sum().item()
        epoch_acc_sum[1] += len(batch_labels)
    
        print(f'[ {epoch:2d}:{num_epochs} ]\tloss={np.mean(epoch_loss_list):.3f}, \
          acc={epoch_acc_sum[0]}/{epoch_acc_sum[1]}={epoch_acc_sum[0]/epoch_acc_sum[1]:.3f}')

  0%|          | 0/422 [00:00<?, ?it/s]

0 1.2761881440454186 [tensor(6747, device='cuda:0'), 13500]


  0%|          | 0/422 [00:00<?, ?it/s]

1 1.1783721616482847 [tensor(7404, device='cuda:0'), 13500]


  0%|          | 0/422 [00:00<?, ?it/s]

2 1.1486803274866528 [tensor(7557, device='cuda:0'), 13500]


  0%|          | 0/422 [00:00<?, ?it/s]

3 1.1271912181546904 [tensor(7712, device='cuda:0'), 13500]


  0%|          | 0/422 [00:00<?, ?it/s]

4 1.1119315944859203 [tensor(7774, device='cuda:0'), 13500]


  0%|          | 0/422 [00:00<?, ?it/s]

5 1.0891381286049342 [tensor(7804, device='cuda:0'), 13500]


  0%|          | 0/422 [00:00<?, ?it/s]

6 1.074612484300306 [tensor(7908, device='cuda:0'), 13500]


  0%|          | 0/422 [00:00<?, ?it/s]