In [1]:
from transformers import AutoModelForImageClassification
from torchvision.transforms import v2 as transformsv2
from torch.utils.data import Dataset
from tqdm.notebook import tqdm
from enum import Enum
from PIL import Image
import numpy as np
import pickle
import random
import torch
import base
import os

In [2]:
# Pick the compute device once; every later cell moves tensors/models to `device`.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
if use_cuda:
    print("GPU is available and will be used:", torch.cuda.get_device_name(0))
else:
    print("GPU is not available, using CPU.")

GPU is available and will be used: NVIDIA A100 80GB PCIe MIG 2g.20gb


In [3]:
def reset_seed(seed=42):
    """Seed Python, NumPy and PyTorch RNGs and force deterministic cuDNN.

    Args:
        seed: integer seed applied to every generator.

    Note:
        Assigning ``PYTHONHASHSEED`` at runtime does not change ``hash()`` in
        the already-running interpreter; it only reaches processes started
        afterwards.
    """
    # Each library keeps its own independent generator state.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    # Trade cuDNN autotuning speed for run-to-run reproducibility.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

    os.environ['PYTHONHASHSEED'] = str(seed)

In [4]:
# Re-seed Python/NumPy/torch RNGs (see reset_seed) before the rest of the setup cells.
reset_seed(42)
# Which split a logits-backed CustomCIFAR10 loads; the functional API numbers members from 1.
dataset_part = Enum('dataset_part', 'TRAIN TEST EVAL')

In [42]:
class CustomCIFAR10(Dataset):
    """Map-style dataset over one raw CIFAR-10 python batch file.

    Reads ``<root>/cifar-10-batches-py/data_batch_<batch>`` when ``train`` is
    True and ``<root>/cifar-10-batches-py/test_batch`` otherwise. Items are
    ``(image, label)`` tensor pairs moved to the notebook-global ``device``.

    Args:
        root: directory that contains ``cifar-10-batches-py``.
        batch: number of the training batch file to read; required when
            ``train`` is True.
        train: read a training batch (True) or the test batch (False).
        transform: callable applied to the PIL image; it must return a tensor.
            When None, the image is returned as a ``uint8`` tensor of shape
            (3, 32, 32).

    Raises:
        ValueError: if ``train`` is True and ``batch`` is None (the path would
            otherwise silently become ``data_batch_None``).
    """

    def __init__(self, root, batch=None, train=True, transform=None):
        self.root = root
        self.train = train
        self.transform = transform

        batches_dir = os.path.join(self.root, 'cifar-10-batches-py')
        if self.train:
            if batch is None:
                raise ValueError("CustomCIFAR10: `batch` is required when train=True")
            self.data_file = os.path.join(batches_dir, f'data_batch_{batch}')
        else:
            self.data_file = os.path.join(batches_dir, 'test_batch')

        # NOTE(review): pickle.load can execute arbitrary code; only use trusted files.
        # encoding='bytes' makes the stored keys come back as bytes (b'data', b'labels').
        with open(self.data_file, 'rb') as fo:
            batch_dict = pickle.load(fo, encoding='bytes')
        self.data = batch_dict[b'data']
        self.labels = batch_dict[b'labels']

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        label = torch.as_tensor(self.labels[index])

        if self.transform is not None:
            # Each stored row is reshaped channel-first (3, 32, 32) and turned
            # into HWC for PIL.
            image = self.data[index].reshape(3, 32, 32).transpose(1, 2, 0)
            image = Image.fromarray(image.astype('uint8'), 'RGB')
            # Seeding with the row index makes any random augmentation a pure
            # function of the sample's position in this file, so the same
            # augmented image can be regenerated later. Side effect: this
            # resets the global torch RNG on every call.
            torch.manual_seed(index)
            image = self.transform(image)
        else:
            image = torch.from_numpy(self.data[index].reshape(3, 32, 32).astype('uint8'))

        return image.to(device), label.to(device)

In [6]:
# Re-seed before building the transforms so this cell starts from a known RNG
# state even if earlier cells were re-run (presumably redundant on Run All).
reset_seed(42)

In [7]:
# NOTE(review): these are the ImageNet statistics. ViT "in21k" checkpoints usually
# ship an image processor with mean=std=0.5 — confirm against the checkpoint's
# preprocessor config before trusting the logits generated below.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Plain pipeline: PIL -> float tensor in [0, 1] -> 224x224 -> normalized.
transform = transformsv2.Compose([
    transformsv2.ToImage(),
    transformsv2.ToDtype(torch.float32, scale=True),
    transformsv2.Resize(size=(224, 224), antialias=True),
    transformsv2.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

# Same pipeline with random flips/rotation inserted before normalization.
# The logits-backed CustomCIFAR10 distinguishes the two pipelines by the length
# of their repr, so keep the transform list in this exact order.
augment_transform = transformsv2.Compose([
    transformsv2.ToImage(),
    transformsv2.ToDtype(torch.float32, scale=True),
    transformsv2.Resize(size=(224, 224), antialias=True),
    transformsv2.RandomHorizontalFlip(),
    transformsv2.RandomVerticalFlip(),
    transformsv2.RandomRotation(15),
    transformsv2.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

In [8]:
# Re-seed right before loading the model so any randomly initialised weights
# (e.g. a resized head) would be reproducible.
reset_seed(42)

In [9]:
MODEL_CHECKPOINT = "aaraki/vit-base-patch16-224-in21k-finetuned-cifar10"

# The checkpoint name indicates a CIFAR-10 fine-tune; num_labels=10 presumably
# matches its existing classification head (confirm via model.config).
model = AutoModelForImageClassification.from_pretrained(MODEL_CHECKPOINT, num_labels=10)
model.to(device)

ViTForImageClassification(
  (vit): ViTModel(
    (embeddings): ViTEmbeddings(
      (patch_embeddings): ViTPatchEmbeddings(
        (projection): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
      )
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): ViTEncoder(
      (layer): ModuleList(
        (0-11): 12 x ViTLayer(
          (attention): ViTSdpaAttention(
            (attention): ViTSdpaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
            (output): ViTSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
          )
          (intermediate): ViTIntermediate(
            (dense): Linear(in_fe

In [10]:
# Put the model in inference mode (Dropout layers become no-ops) before any logits are generated.
model.eval()

ViTForImageClassification(
  (vit): ViTModel(
    (embeddings): ViTEmbeddings(
      (patch_embeddings): ViTPatchEmbeddings(
        (projection): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
      )
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): ViTEncoder(
      (layer): ModuleList(
        (0-11): 12 x ViTLayer(
          (attention): ViTSdpaAttention(
            (attention): ViTSdpaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
            (output): ViTSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
          )
          (intermediate): ViTIntermediate(
            (dense): Linear(in_fe

In [11]:
def unpickle(file):
    """Load and return the object pickled in ``file``.

    ``encoding='bytes'`` decodes Python-2 ``str`` objects as ``bytes``, which is
    why callers index the CIFAR batch dicts with ``b'data'`` / ``b'labels'``.

    WARNING: unpickling can execute arbitrary code — only use on trusted files.
    """
    with open(file, 'rb') as handle:
        return pickle.load(handle, encoding='bytes')

def pickle_up(file, contents):
    """Serialize ``contents`` into ``file`` (overwriting it) with the newest pickle protocol."""
    with open(file, 'wb') as out:
        pickle.dump(contents, out, pickle.HIGHEST_PROTOCOL)

In [38]:
def generate_logits(dataloder, net=None):
    """Run a classifier over every batch and collect its logits on the CPU.

    Args:
        dataloder: iterable yielding ``(pixel_values, labels)`` pairs whose
            pixel values are already on the model's device. Labels are ignored.
        net: model to evaluate; defaults to the notebook-global ``model``.
            Its output must expose a ``.logits`` tensor. The caller is
            responsible for putting it in eval mode.

    Returns:
        A list with one NumPy array of logits per batch, in loader order
        (use ``flatten_logits`` to get one row per sample).
    """
    if net is None:
        net = model
    logits_arr = []
    # Gradients are never needed here, so disable autograd for the whole pass.
    with torch.no_grad():
        for pixel_values, _labels in tqdm(dataloder):
            logits_arr.append(net(pixel_values).logits.cpu().numpy())
    return logits_arr

In [40]:
def flatten_logits(logits_arr):
    """Concatenate per-batch logit arrays into one flat list of per-sample rows.

    Args:
        logits_arr: sequence of per-batch containers (e.g. the NumPy arrays
            returned by ``generate_logits``); each is iterated along its first axis.

    Returns:
        A list holding every row of every batch, in order.
    """
    return [row for batch_logits in logits_arr for row in batch_logits]

In [45]:
# Cache the model's logits for the held-out splits: the official test batch and
# data_batch_5 (saved as the "eval" split). Each output pickle is the original
# batch dict plus a b"logits" list with one row per image, in file order.
os.makedirs("data/10-logits/cifar-10-batches-py", exist_ok=True)

testing = unpickle("data/10/cifar-10-batches-py/test_batch")
test_data = CustomCIFAR10(root='./data/10', train=False, transform=transform)
test_dataloader = torch.utils.data.DataLoader(test_data, batch_size=128, shuffle=False)

logits_test = flatten_logits(generate_logits(test_dataloader))
assert len(logits_test) == len(testing[b"labels"])
testing[b"logits"] = logits_test
pickle_up("data/10-logits/cifar-10-batches-py/test", testing)


evaluating = unpickle("data/10/cifar-10-batches-py/data_batch_5")
eval_data = CustomCIFAR10(root='./data/10', train=True, batch=5, transform=transform)
eval_dataload = torch.utils.data.DataLoader(eval_data, batch_size=128, shuffle=False)

logits_eval = flatten_logits(generate_logits(eval_dataload))
assert len(logits_eval) == len(evaluating[b"labels"])
evaluating[b"logits"] = logits_eval
pickle_up("data/10-logits/cifar-10-batches-py/eval", evaluating)

  0%|          | 0/79 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

In [23]:
# CIFAR-10: cache the model's logits for training batch files 1-4, for the plain
# view (b"logits") and the augmented view (b"logits_aug"). The raw-batch dataset
# seeds torch with each sample's index inside its file, so the same augmented
# image can be regenerated later from that index.
os.makedirs("data/10-logits/cifar-10-batches-py", exist_ok=True)

reset_seed(42)
for index in range(1, 5):

    data = unpickle(f"data/10/cifar-10-batches-py/data_batch_{index}")

    train = CustomCIFAR10(root='./data/10', batch=index, train=True, transform=transform)
    train_augmented = CustomCIFAR10(root='./data/10', batch=index, train=True, transform=augment_transform)

    # (output key, dataset, progress-bar label) for the plain and augmented passes.
    passes = (
        (b"logits", train, f"Progress for file {index}."),
        (b"logits_aug", train_augmented, f"Progress for file {index} with augmentation."),
    )
    for key, dataset, desc in passes:
        loader = torch.utils.data.DataLoader(dataset, batch_size=64, shuffle=False)
        batch_logits = []
        with torch.no_grad():
            for pixel_values, _labels in tqdm(loader, desc=desc):
                batch_logits.append(model(pixel_values).logits.cpu().numpy())
        # One logit row per image, in file order, aligned with data[b"labels"].
        data[key] = [row for arr in batch_logits for row in arr]
        assert len(data[key]) == len(data[b"labels"])

    pickle_up(f"data/10-logits/cifar-10-batches-py/train_batch_{index}", data)

Progress for file 1.:   0%|          | 0/157 [00:00<?, ?it/s]

Progress for file 1 with augmentation.:   0%|          | 0/157 [00:00<?, ?it/s]

Progress for file 2.:   0%|          | 0/157 [00:00<?, ?it/s]

Progress for file 2 with augmentation.:   0%|          | 0/157 [00:00<?, ?it/s]

Progress for file 3.:   0%|          | 0/157 [00:00<?, ?it/s]

Progress for file 3 with augmentation.:   0%|          | 0/157 [00:00<?, ?it/s]

Progress for file 4.:   0%|          | 0/157 [00:00<?, ?it/s]

Progress for file 4 with augmentation.:   0%|          | 0/157 [00:00<?, ?it/s]

In [30]:
class CustomCIFAR10(Dataset):
    """CIFAR-10 dataset backed by the logit-augmented pickles written earlier.

    NOTE: this reuses the name ``CustomCIFAR10`` and shadows the raw-batch class
    defined earlier in the notebook.

    Files read from ``<root>/cifar-10-batches-py``:
        * ``dataset_part.TRAIN``: ``train_batch_1`` .. ``train_batch_4``
        * ``dataset_part.TEST``: ``test``
        * any other part (``EVAL``): ``eval``

    Each item is a dict with
        * ``'pixel_values'``: ``transform(PIL image)`` when a transform is set,
          otherwise the raw image as a ``uint8`` tensor of shape (3, 32, 32);
        * ``'labels'``: the class label;
        * ``'logits'``: float tensor of the stored model logits for that image.

    Args:
        root: directory containing ``cifar-10-batches-py``.
        dataset_part: which split to load (a member of the ``dataset_part`` enum).
        transform: callable applied to the PIL image.
        use_aug_logits: True to return ``b"logits_aug"`` rows, False for
            ``b"logits"``. ``None`` (default) keeps the legacy behaviour:
            augmented logits iff ``len(transform.extra_repr()) >= 300``, plain
            logits when there is no transform.
    """

    def __init__(self, root, dataset_part=dataset_part.TRAIN, transform=None, use_aug_logits=None):
        self.root = root
        self.dataset_part = dataset_part
        self.transform = transform
        self.use_aug_logits = use_aug_logits
        self.data = []
        self.targets = []
        self.logits = []
        self.logits_aug = []
        # Augmentation seed per sample = its row index inside its source file.
        # This is the index the raw-batch dataset seeded with when b"logits_aug"
        # was generated, and it survives remove_entries (unlike a positional index).
        self.seeds = []

        # The `dataset_part` argument (an enum member) shadows the module-level
        # enum, so compare by member name instead of member-of-member access.
        part_name = self.dataset_part.name
        if part_name == 'TRAIN':
            file_names = [f'train_batch_{i}' for i in range(1, 5)]
        elif part_name == 'TEST':
            file_names = ['test']
        else:
            file_names = ['eval']

        batches_dir = os.path.join(self.root, 'cifar-10-batches-py')
        for file_name in file_names:
            # NOTE(review): pickle.load can execute arbitrary code; trusted files only.
            with open(os.path.join(batches_dir, file_name), 'rb') as fo:
                batch_dict = pickle.load(fo, encoding='bytes')
            self.data.append(batch_dict[b'data'])
            self.targets.extend(batch_dict[b'labels'])
            # The test/eval pickles carry b'logits' but no b'logits_aug'.
            self.logits.extend(batch_dict.get(b'logits', []))
            self.logits_aug.extend(batch_dict.get(b'logits_aug', []))
            self.seeds.append(np.arange(len(batch_dict[b'data'])))

        self.data = np.concatenate(self.data, axis=0)
        self.seeds = np.concatenate(self.seeds)

    def __len__(self):
        return len(self.data)

    def _select_logits(self):
        """Return the logits collection that matches how ``pixel_values`` are produced."""
        if self.use_aug_logits is not None:
            use_aug = self.use_aug_logits
        elif self.transform is None:
            use_aug = False
        else:
            # Legacy heuristic kept for backward compatibility: the augmenting
            # Compose has a longer repr than the plain one.
            use_aug = len(self.transform.extra_repr()) >= 300
        source = self.logits_aug if use_aug else self.logits
        if len(source) == 0:
            kind = "augmented logits" if use_aug else "logits"
            raise ValueError(f"No {kind} were loaded for dataset part {self.dataset_part.name}")
        return source

    def __getitem__(self, index):
        # Each stored row is reshaped channel-first (3, 32, 32).
        image = self.data[index].reshape(3, 32, 32)
        label = self.targets[index]
        logit = self._select_logits()[index]

        if self.transform is not None:
            image = Image.fromarray(image.transpose(1, 2, 0).astype('uint8'), 'RGB')
            # Re-create the exact augmentation whose logits were stored
            # (side effect: resets the global torch RNG).
            torch.manual_seed(int(self.seeds[index]))
            image = self.transform(image)
        else:
            image = torch.from_numpy(image.astype('uint8'))

        return {
            'pixel_values': image,
            'labels': label,
            'logits': torch.tensor(logit, dtype=torch.float),
        }

    def remove_entries(self, remove_list):
        """Drop the samples at positions ``remove_list`` from every per-sample array.

        Afterwards ``targets``/``logits``/``logits_aug``/``seeds`` are NumPy
        arrays. Logit collections that were never loaded (empty) are left as is.
        """
        self.data = np.delete(self.data, remove_list, axis=0)
        self.targets = np.delete(self.targets, remove_list, axis=0)
        self.seeds = np.delete(self.seeds, remove_list, axis=0)
        if len(self.logits):
            self.logits = np.delete(self.logits, remove_list, axis=0)
        if len(self.logits_aug):
            self.logits_aug = np.delete(self.logits_aug, remove_list, axis=0)

    @property
    def labels(self):
        """Class labels, aligned with ``data`` (alias of ``targets``)."""
        return self.targets


In [31]:
# Same four training files seen two ways: with the plain transform and with the
# augmenting one. CustomCIFAR10 chooses which stored logits list to return
# based on the transform it is given.
train = CustomCIFAR10(root='./data/10-logits', dataset_part=dataset_part.TRAIN, transform=transform)
train_aug = CustomCIFAR10(root='./data/10-logits', dataset_part=dataset_part.TRAIN, transform=augment_transform)
train_combo = torch.utils.data.ConcatDataset([train, train_aug])

In [32]:
def check_acc(dataset, name="base train set"):
    """Top-1 agreement between each sample's stored logits and its label.

    Args:
        dataset: indexable dataset whose items are dicts with ``"logits"``
            (1-D float tensor) and ``"labels"``.
        name: label used in the progress bar and the result string.

    Returns:
        ``"Accuracy for <name>: <fraction>"``.

    Raises:
        ValueError: if ``dataset`` is empty.
    """
    total = len(dataset)
    if total == 0:
        raise ValueError(f"check_acc: dataset '{name}' is empty")

    n_correct = 0
    for i in tqdm(range(total), desc=f"Progress for {name}"):
        sample = dataset[i]
        predicted = torch.topk(sample["logits"], k=1).indices.numpy()[0]
        if predicted == sample["labels"]:
            n_correct += 1

    return f"Accuracy for {name}: {n_correct / total}"

In [33]:
for dataset_to_check in (train, train_aug, train_combo):
    print(check_acc(dataset_to_check))

Progress for base train set:   0%|          | 0/40000 [00:00<?, ?it/s]

Accuracy for base train set: 0.954925


Progress for base train set:   0%|          | 0/40000 [00:00<?, ?it/s]

Accuracy for base train set: 0.686


Progress for base train set:   0%|          | 0/80000 [00:00<?, ?it/s]

Accuracy for base train set: 0.8204625


In [34]:
# Positions where the augmented view's stored prediction disagrees with the plain
# view's. Read the stored logit arrays directly instead of iterating the datasets,
# which would run both image pipelines (resize to 224, augment) for every sample
# only to discard the pixels. argmax == top-1 (ties between float logits aside).
assert len(train_aug) == len(train), (
    "train_aug has already been filtered; rebuild it before recomputing rem_ls"
)
pred_aug = np.argmax(np.stack(train_aug.logits_aug), axis=1)
pred_base = np.argmax(np.stack(train.logits), axis=1)
rem_ls = np.flatnonzero(pred_aug != pred_base).tolist()

In [35]:
# remove_entries shrinks train_aug in place, and rem_ls holds positions in the
# unfiltered dataset, so only apply it once (keeps this cell safe to re-run).
if len(train_aug) == len(train):
    train_aug.remove_entries(rem_ls)
train_combo = torch.utils.data.ConcatDataset([train, train_aug])

In [36]:
for dataset_to_check in (train_aug, train_combo):
    print(check_acc(dataset_to_check))

Progress for base train set:   0%|          | 0/28176 [00:00<?, ?it/s]

Accuracy for base train set: 0.9614565587734242


Progress for base train set:   0%|          | 0/68176 [00:00<?, ?it/s]

Accuracy for base train set: 0.9576243839474302
