In [62]:
import os
import pickle
import random
from enum import Enum

import numpy as np
import torch
from PIL import Image
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, Subset
from torchvision.transforms import v2 as transformsv2
from tqdm.notebook import tqdm
from transformers import AutoModelForImageClassification

In [63]:
# Pick the compute device once: CUDA when torch can see a GPU, CPU otherwise.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if device.type == "cuda":
    print("GPU is available and will be used:", torch.cuda.get_device_name(0))
else:
    print("GPU is not available, using CPU.")

GPU is available and will be used: NVIDIA A100 80GB PCIe MIG 2g.20gb


In [64]:
def reset_seed(seed=42):
    """Re-seed every random number generator this notebook touches.

    Seeds Python's ``random``, NumPy's legacy global generator and torch
    (default generator plus the CUDA generators), stores the seed in the
    ``PYTHONHASHSEED`` environment variable, and switches cuDNN to
    deterministic mode with benchmarking turned off.

    Args:
        seed: Integer seed applied to all generators. Defaults to 42.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)

    random.seed(seed)
    np.random.seed(seed)
    for seed_fn in (torch.manual_seed, torch.cuda.manual_seed, torch.cuda.manual_seed_all):
        seed_fn(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

In [65]:
# Seed all generators handled by reset_seed with the notebook-wide value 42.
reset_seed(42)

In [66]:
class CustomCIFAR100(Dataset):
    """CIFAR-100 split read straight from the pickled ``cifar-100-python`` files.

    Args:
        root: Directory that contains the ``cifar-100-python`` folder.
        train: Read the ``train`` file when true, the ``test`` file otherwise.
        transform: Optional callable applied to the PIL image of each sample.

    Attributes:
        data_file: Path of the pickle that was loaded.
        data: The ``b'data'`` entry of the pickle; one flat row per image.
        labels: The ``b'fine_labels'`` entry of the pickle.
    """

    def __init__(self, root, train=True, transform=None):
        self.root = root
        self.train = train
        self.transform = transform

        split_name = 'train' if self.train else 'test'
        self.data_file = os.path.join(self.root, 'cifar-100-python', split_name)

        # NOTE: unpickling can execute arbitrary code; only point this at trusted files.
        with open(self.data_file, 'rb') as fo:
            # Named ``payload`` so the built-in ``dict`` is not shadowed.
            payload = pickle.load(fo, encoding='bytes')
        self.data = payload[b'data']
        self.labels = payload[b'fine_labels']

    def __len__(self):
        """Return the number of rows in ``data``."""
        return len(self.data)

    def __getitem__(self, index):
        """Return ``(image, label)`` for ``index`` with tensors moved to ``device``.

        The flat row is reshaped to 3x32x32, transposed to height-width-channel
        order and wrapped in an RGB PIL image before ``transform`` is applied.
        Without a transform the PIL image is returned as is, because it has no
        ``.to`` method to move it to a device.
        """
        image = self.data[index].reshape(3, 32, 32).transpose(1, 2, 0)
        label = torch.as_tensor(self.labels[index])
        image = Image.fromarray(image.astype('uint8'), 'RGB')

        if self.transform:
            # Re-seeding the global torch generator with the sample index makes
            # random transforms repeatable per sample; it also overwrites the
            # global RNG state on every access.
            torch.manual_seed(index)
            image = self.transform(image)

        # Only objects that support ``.to`` (tensors) can be moved; previously a
        # PIL image reached this point when no transform was set and crashed.
        if hasattr(image, 'to'):
            image = image.to(device)

        return image, label.to(device)
        

In [67]:
# Re-apply seed 42 immediately before the transforms and datasets are created.
reset_seed(42)

In [68]:
# Deterministic preprocessing: convert to a float32 image scaled by ToDtype,
# resize to 224x224 and normalise per channel.
# NOTE(review): mean/std look like the usual ImageNet statistics — confirm they
# match what the checkpoint was trained with.
transform = transformsv2.Compose([
    transformsv2.ToImage(),
    transformsv2.ToDtype(torch.float32, scale=True),
    transformsv2.Resize((224, 224), antialias=True),
    transformsv2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Same pipeline plus random horizontal/vertical flips and a random rotation
# (argument 15) inserted before normalisation.
# The second CustomCIFAR100 definition further down tells the two pipelines
# apart by the length of ``extra_repr()`` (threshold 300), so adding or
# removing steps here can silently change which logits that class returns.
augment_transform = transformsv2.Compose([ 
    transformsv2.ToImage(),
    transformsv2.ToDtype(torch.float32, scale=True),
    transformsv2.Resize(size=(224, 224), antialias=True),
    transformsv2.RandomHorizontalFlip(),
    transformsv2.RandomVerticalFlip(),
    transformsv2.RandomRotation(15),
    transformsv2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Two views of the same training file: one plain, one augmented.
train = CustomCIFAR100(root='./data/100', train=True, transform=transform)
train_aug = CustomCIFAR100(root='./data/100', train=True, transform=augment_transform)

In [70]:
# Stratified, shuffled 80/20 split over the positions of the training file.
# Only index arrays are produced here; the datasets are subset in a later cell.
# random_state is fixed so the same partition is produced on every run.
train_idx, validation_idx = train_test_split(
    np.arange(len(train)),
    test_size=0.2,
    random_state=42,
    shuffle=True,
    stratify=train.labels,
)

In [71]:
# Restrict both training views to the training indices.
# The names are rebound to the Subset wrappers, so this cell is not idempotent:
# running it twice would wrap a Subset in another Subset using the same indices.
train = Subset(train, train_idx)
train_aug = Subset(train_aug, train_idx)



In [90]:
# Test split, plus a fresh un-augmented copy of the training file from which
# the validation subset is taken in a later cell.
# Note: the name ``eval`` rebinds Python's built-in ``eval`` for the rest of the session.
test = CustomCIFAR100(root='./data/100', train=False, transform=transform)
eval = CustomCIFAR100(root='./data/100', train=True, transform=transform)

In [74]:
# Keep only the validation indices produced by the stratified split.
eval = Subset(eval, validation_idx)

In [76]:
# shuffle=False keeps batches in subset order, so the concatenated logits line
# up position-by-position with train_idx.
train_dataloader = torch.utils.data.DataLoader(train, batch_size=128, shuffle=False)
train_dataloader_aug = torch.utils.data.DataLoader(train_aug, batch_size=128, shuffle=False)

In [91]:
# Ordered (unshuffled) loaders for the validation subset and the test split.
# ``eval_dataloder`` is spelled this way where it is used later as well.
eval_dataloder = torch.utils.data.DataLoader(eval, batch_size=128, shuffle=False)
test_dataloader = torch.utils.data.DataLoader(test, batch_size=128, shuffle=False)

In [78]:
# Load the "Ahmed9275/Vit-Cifar100" checkpoint with a 100-way classification head.
model = AutoModelForImageClassification.from_pretrained(
    "Ahmed9275/Vit-Cifar100",
    num_labels=100,
)

# Move the model to the selected device; as the last expression of the cell the
# returned module is also echoed as the cell output.
model.to(device)

ViTForImageClassification(
  (vit): ViTModel(
    (embeddings): ViTEmbeddings(
      (patch_embeddings): ViTPatchEmbeddings(
        (projection): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
      )
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): ViTEncoder(
      (layer): ModuleList(
        (0-11): 12 x ViTLayer(
          (attention): ViTSdpaAttention(
            (attention): ViTSdpaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
            (output): ViTSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
          )
          (intermediate): ViTIntermediate(
            (dense): Linear(in_fe

In [79]:
# Switch the model to evaluation mode before the logits are generated.
model.eval()

ViTForImageClassification(
  (vit): ViTModel(
    (embeddings): ViTEmbeddings(
      (patch_embeddings): ViTPatchEmbeddings(
        (projection): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
      )
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): ViTEncoder(
      (layer): ModuleList(
        (0-11): 12 x ViTLayer(
          (attention): ViTSdpaAttention(
            (attention): ViTSdpaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
            (output): ViTSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
          )
          (intermediate): ViTIntermediate(
            (dense): Linear(in_fe

In [80]:
def unpickle(file):
    """Return the object stored in the pickle at path ``file``.

    The file is opened in binary mode and read with ``encoding='bytes'``.
    Unpickling can execute arbitrary code, so only use this on trusted files.
    """
    with open(file, 'rb') as handle:
        return pickle.load(handle, encoding='bytes')

def pickle_up(file, contents):
    """Pickle ``contents`` to path ``file`` with the highest available protocol.

    Missing parent directories are created first, so writing into a folder that
    does not exist yet no longer fails with ``FileNotFoundError``. An existing
    file at ``file`` is overwritten.

    Args:
        file: Destination path.
        contents: Any picklable object.
    """
    parent = os.path.dirname(file)
    if parent:  # empty when ``file`` is a bare name in the working directory
        os.makedirs(parent, exist_ok=True)

    with open(file, 'wb') as fo:
        pickle.dump(contents, fo, protocol=pickle.HIGHEST_PROTOCOL)

In [81]:
# Empty placeholders; both names are reassigned once the logits are generated.
logits, logits_aug = [], []

In [82]:
def generate_logits(dataloder):
    """Run the global ``model`` over every batch and collect its logits.

    Args:
        dataloder: Iterable yielding ``(pixel_values, labels)`` pairs; the
            labels are ignored.

    Returns:
        A list with one NumPy array of logits per batch, in iteration order.
    """

    @torch.no_grad()
    def _batch_logits(pixel_values):
        # Forward pass without gradient tracking, result copied to host memory.
        return model(pixel_values).logits.cpu().numpy()

    return [_batch_logits(pixel_values) for pixel_values, _ in tqdm(dataloder)]

In [83]:
def flatten_logits(logits_arr):
    """Concatenate per-batch sequences into one flat list of per-sample entries.

    Args:
        logits_arr: Iterable of batches, each itself iterable over samples.

    Returns:
        A list holding every sample of every batch, in the original order.
    """
    return [sample for batch in logits_arr for sample in batch]

In [84]:
# Model logits for the training subset, once through the plain pipeline and
# once through the augmented one.
logits = generate_logits(train_dataloader)
logits_aug = generate_logits(train_dataloader_aug)

# Same forward pass over the validation subset.
logits_eval = generate_logits(eval_dataloder)

  0%|          | 0/313 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

In [92]:
# Model logits for the test split, one array per batch.
logits_test = generate_logits(test_dataloader)

  0%|          | 0/79 [00:00<?, ?it/s]

In [85]:
# Turn the per-batch arrays into one entry per sample so they can be stored
# alongside the per-sample records.
logits_flat = flatten_logits(logits)
logits_aug_flat = flatten_logits(logits_aug)

logits_eval_flat = flatten_logits(logits_eval)

In [93]:
# One logits entry per test sample.
logits_test_flat = flatten_logits(logits_test)

In [86]:
# Reload the raw training pickle so its records can be re-split and extended.
data_file = unpickle("data/100/cifar-100-python/train")

In [87]:
def _pick_rows(source, indices):
    """Copy every per-sample field of ``source`` restricted to ``indices``.

    ``b"batch_label"` is skipped; every other value is indexed once per
    requested position, in the order given.
    """
    picked = {}
    for key, value in source.items():
        if key == b"batch_label":
            continue
        picked[key] = [value[i] for i in indices]
    return picked


# Training and validation records in the order of the stratified split.
data = _pick_rows(data_file, train_idx)
eval_data = _pick_rows(data_file, validation_idx)

In [88]:
# Store the plain and augmented logits next to the training records.
data[b"logits"] = logits_flat
data[b"logits_aug"] = logits_aug_flat

# The validation records only receive plain logits (no b"logits_aug" key).
eval_data[b"logits"] = logits_eval_flat

In [94]:
# Extend the original test pickle with its logits and write it to the
# "100-logits" tree. Like the validation file, it has no b"logits_aug" key.
testing = unpickle("data/100/cifar-100-python/test")
testing[b"logits"] = logits_test_flat
pickle_up("data/100-logits/cifar-100-python/test", testing)

In [89]:
# Persist the re-split training and validation records, including their logits.
pickle_up("data/100-logits/cifar-100-python/train", data)
pickle_up("data/100-logits/cifar-100-python/eval", eval_data)

In [23]:
class dataset_part(Enum):
    """Identifies which stored split a dataset should read."""

    TRAIN = 1
    TEST = 2
    EVAL = 3

In [34]:
class CustomCIFAR100(Dataset):
    """CIFAR-100 split stored together with precomputed model logits.

    This definition replaces the earlier class of the same name. It reads the
    pickles that carry ``b'logits'`` (and, where present, ``b'logits_aug'``).

    Args:
        root: Directory that contains the ``cifar-100-python`` folder.
        dataset_part: Member of the ``dataset_part`` Enum. ``TRAIN`` and
            ``TEST`` read the files of the same name; anything else reads
            ``eval``.
        transform: Optional callable applied to the PIL image of each sample.
        use_aug_logits: ``True``/``False`` selects ``logits_aug``/``logits``
            explicitly. ``None`` (default) keeps the original heuristic: the
            augmented logits are used when ``transform.extra_repr()`` is at
            least 300 characters long.

    NOTE(review): ``__getitem__`` seeds torch with the position inside this
    file. Confirm that this is the same seed that was active when
    ``logits_aug`` was computed, otherwise image and logits do not correspond.
    """

    # Built while ``dataset_part`` still names the module-level Enum. Inside
    # ``__init__`` the parameter of the same name shadows it, and looking up
    # ``TRAIN`` on the parameter means reaching one member through another.
    # NOTE(review): that member-through-member lookup is not supported by every
    # Python version — confirm against the enum documentation.
    _SPLIT_FILES = {
        dataset_part.TRAIN: 'train',
        dataset_part.TEST: 'test',
        dataset_part.EVAL: 'eval',
    }

    def __init__(self, root, dataset_part = dataset_part.TRAIN, transform=None, use_aug_logits=None):
        self.root = root
        self.dataset_part = dataset_part
        self.transform = transform
        self.use_aug_logits = use_aug_logits

        split_file = self._SPLIT_FILES.get(self.dataset_part, 'eval')
        data_file = os.path.join(self.root, 'cifar-100-python', split_file)

        # NOTE: unpickling can execute arbitrary code; only load trusted files.
        with open(data_file, 'rb') as fo:
            payload = pickle.load(fo, encoding='bytes')

        self.data = np.asarray(payload[b'data'])
        self.targets = list(payload[b'fine_labels'])
        self.logits = list(payload[b'logits'])
        # The eval and test files are written without augmented logits, so the
        # key is optional; requiring it made those splits impossible to load.
        self.logits_aug = list(payload.get(b'logits_aug', []))

    def __len__(self):
        """Return the number of rows in ``data``."""
        return len(self.data)

    def _wants_aug_logits(self):
        """Decide whether ``__getitem__`` should serve the augmented logits."""
        if self.use_aug_logits is not None:
            return bool(self.use_aug_logits)
        if not self.transform:
            return False
        describe = getattr(self.transform, 'extra_repr', None)
        if describe is None:
            return False
        # Original heuristic: a long description means the augmenting pipeline.
        return len(describe()) >= 300

    def __getitem__(self, index):
        """Return a dict with ``pixel_values``, ``labels`` and ``logits``.

        The flat row is reshaped to 3x32x32, transposed to height-width-channel
        order and wrapped in an RGB PIL image before ``transform`` is applied.
        Without a transform the PIL image is returned unchanged together with
        the plain logits (previously this path failed on an unbound variable).

        Raises:
            KeyError: If augmented logits are requested but none were loaded.
        """
        image = self.data[index].reshape(3, 32, 32).transpose(1, 2, 0)
        label = self.targets[index]
        image = Image.fromarray(image.astype('uint8'), 'RGB')

        if self._wants_aug_logits():
            if len(self.logits_aug) == 0:
                raise KeyError("no b'logits_aug' entries were loaded for this split")
            logit = self.logits_aug[index]
        else:
            logit = self.logits[index]

        if self.transform:
            # Seeding with the index makes random transforms repeatable per
            # position; it also overwrites the global torch RNG state.
            torch.manual_seed(index)
            image = self.transform(image)

        logit = torch.tensor(logit, dtype=torch.float)

        return {
            'pixel_values': image,
            'labels': label,
            'logits': logit
        }

    def remove_entries(self, remove_list):
        """Drop the samples at the positions in ``remove_list`` from all fields.

        Each field is replaced by the array returned from ``np.delete``, so
        ``targets``, ``logits`` and ``logits_aug`` are NumPy arrays afterwards.
        Positions of the remaining samples shift down, which also changes the
        index that ``__getitem__`` uses as seed for them.
        """
        self.data = np.delete(self.data, remove_list, axis=0)
        self.targets = np.delete(self.targets, remove_list, axis=0)
        self.logits = np.delete(self.logits, remove_list, axis=0)
        self.logits_aug = np.delete(self.logits_aug, remove_list, axis=0)

    @property
    def labels(self):
        """Alias for ``targets``."""
        return self.targets

In [35]:
# Rebuild both training views from the pickles that include logits; these
# rebind the names used for the raw datasets earlier in the notebook.
train_aug = CustomCIFAR100(root='./data/100-logits', dataset_part=dataset_part.TRAIN, transform=augment_transform)
train = CustomCIFAR100(root='./data/100-logits', dataset_part=dataset_part.TRAIN, transform=transform)
# Plain samples first, augmented samples second.
train_combo = torch.utils.data.ConcatDataset([train, train_aug])

In [36]:
def check_acc(dataset, name="base train set"):
    """Report how often the stored logits predict the stored label.

    Args:
        dataset: Sized iterable of dicts with a 1-D ``"logits"`` tensor and a
            ``"labels"`` entry.
        name: Label used in the progress bar and in the returned message. The
            default reproduces the original hard-coded wording, which was also
            printed for the augmented and combined sets.

    Returns:
        The string ``"Accuracy for <name>: <fraction>"``. An empty dataset
        yields ``nan`` instead of raising ``ZeroDivisionError``.
    """
    total = len(dataset)
    hits = 0
    for val in tqdm(dataset, desc=f"Progress for {name}"):
        predicted = torch.topk(val["logits"], k=1).indices.numpy()[0]
        if predicted == val["labels"]:
            hits += 1

    accuracy = hits / total if total else float("nan")
    return f"Accuracy for {name}: {accuracy}"

In [37]:
# Accuracy of the stored logits for the plain, augmented and combined sets.
# check_acc is called without a label, so all three lines are printed with the
# same "base train set" wording.
print(check_acc(train))
print(check_acc(train_aug))
print(check_acc(train_combo))

Progress for base train set:   0%|          | 0/40000 [00:00<?, ?it/s]

Accuracy for base train set: 0.94035


Progress for base train set:   0%|          | 0/40000 [00:00<?, ?it/s]

Accuracy for base train set: 0.6315


Progress for base train set:   0%|          | 0/80000 [00:00<?, ?it/s]

Accuracy for base train set: 0.785925


In [38]:
# Positions where the top-1 class of the augmented logits differs from the
# top-1 class of the plain logits for the same sample.
# The stored logits are compared directly: iterating the datasets would decode
# and transform every image twice only to read values that are already in memory.
# Both views must still be aligned; this fails if entries were already removed.
assert len(train_aug.logits_aug) == len(train.logits), \
    "train and train_aug are no longer aligned; rebuild them before filtering"


def _top1(stored_logits):
    """Return the top-1 class index for every row of ``stored_logits``."""
    stacked = torch.as_tensor(np.asarray(stored_logits), dtype=torch.float)
    return torch.topk(stacked, k=1, dim=1).indices[:, 0]


rem_ls = torch.nonzero(_top1(train_aug.logits_aug) != _top1(train.logits)).flatten().tolist()

In [39]:
# Drop the disagreeing augmented samples in place, then rebuild the combined
# dataset. The cell mutates train_aug, so it must not be run a second time
# with the same rem_ls.
train_aug.remove_entries(rem_ls)
train_combo = torch.utils.data.ConcatDataset([train, train_aug])

In [40]:
# Accuracy after filtering; both lines still carry the "base train set" label.
print(check_acc(train_aug))
print(check_acc(train_combo))

Progress for base train set:   0%|          | 0/25912 [00:00<?, ?it/s]

Accuracy for base train set: 0.9609833281877123


Progress for base train set:   0%|          | 0/65912 [00:00<?, ?it/s]

Accuracy for base train set: 0.9484615851438282
