In [None]:
%pip install transformers[torch] huggingface_hub datasets evaluate torchvision ipywidgets

In [1]:
from torch.utils.data import DataLoader, Subset, ConcatDataset
from transformers import AutoModelForImageClassification
from sklearn.model_selection import train_test_split
import numpy as np
import torch

import base

In [2]:
# Use CUDA when PyTorch can see a GPU, otherwise fall back to the CPU,
# and report which one was picked.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if device.type == "cuda":
    print("GPU is available and will be used:", torch.cuda.get_device_name(0))
else:
    print("GPU is not available, using CPU.")

GPU is available and will be used: NVIDIA A100 80GB PCIe MIG 2g.20gb


In [3]:
# Fix the random seed (42) through the project helper before any data or model work.
# NOTE(review): which RNGs `reset_seed` covers is defined in `base` — confirm there.
base.reset_seed(42)

In [4]:
# Transform pipelines built by the project's `base` helpers.
transform = base.base_transforms()
augment_transform = base.aug_transforms()

# Two views of the training split (train=True) that differ only in the
# transform handed to the dataset.
train, train_aug = (
    base.CustomCIFAR100(root='./data/100', train=True, transform=pipeline)
    for pipeline in (transform, augment_transform)
)

# Test split, plus a third copy of the training split that is later reduced
# to the validation indices.
# NOTE(review): `eval` shadows the Python builtin of the same name from here on.
test = base.CustomCIFAR100(root='./data/100', train=False, transform=transform)
eval = base.CustomCIFAR100(root='./data/100', train=True, transform=transform)

In [5]:
# Stratified 80/20 split over the positions of the training set, with a fixed
# random_state; the dataset's `labels` attribute is the stratification target.
train_idx, validation_idx = train_test_split(
    np.arange(len(train)),
    test_size=0.2,
    random_state=42,
    shuffle=True,
    stratify=train.labels,
)

In [6]:
# Narrow each training-split view to its share of the indices: both `train`
# variants take `train_idx`, `eval` takes `validation_idx`.
# NOTE: the names are rebound to Subset wrappers, so re-running this cell on its
# own would wrap the existing subsets a second time.
train, train_aug = (Subset(full_set, train_idx) for full_set in (train, train_aug))
eval = Subset(eval, validation_idx)

In [7]:
# Fixed-order loaders (shuffle=False, 128 samples per batch) over the plain and
# the augmented training subsets.
train_dataloader, train_dataloader_aug = (
    DataLoader(subset, batch_size=128, shuffle=False)
    for subset in (train, train_aug)
)

In [8]:
# Fixed-order loaders (shuffle=False, 128 samples per batch) for the validation
# subset and the test split.
# NOTE(review): `eval_dataloder` is misspelled, but the logits cell refers to it
# by this exact name, so it is kept as-is here.
eval_dataloder, test_dataloader = (
    DataLoader(split, batch_size=128, shuffle=False)
    for split in (eval, test)
)

In [9]:
# Load the "Ahmed9275/Vit-Cifar100" checkpoint, configured for 100 labels.
model = AutoModelForImageClassification.from_pretrained("Ahmed9275/Vit-Cifar100", num_labels=100)

# Move the model to the selected device. Being the cell's last expression, the
# returned module is also displayed.
model.to(device)

ViTForImageClassification(
  (vit): ViTModel(
    (embeddings): ViTEmbeddings(
      (patch_embeddings): ViTPatchEmbeddings(
        (projection): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
      )
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): ViTEncoder(
      (layer): ModuleList(
        (0-11): 12 x ViTLayer(
          (attention): ViTSdpaAttention(
            (attention): ViTSdpaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
            (output): ViTSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
          )
          (intermediate): ViTIntermediate(
            (dense): Linear(in_fe

In [10]:
# Put the model in evaluation mode before running inference.
# As the last expression in the cell, the returned module is also displayed.
model.eval()

ViTForImageClassification(
  (vit): ViTModel(
    (embeddings): ViTEmbeddings(
      (patch_embeddings): ViTPatchEmbeddings(
        (projection): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
      )
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): ViTEncoder(
      (layer): ModuleList(
        (0-11): 12 x ViTLayer(
          (attention): ViTSdpaAttention(
            (attention): ViTSdpaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
            (output): ViTSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
          )
          (intermediate): ViTIntermediate(
            (dense): Linear(in_fe

In [11]:
# Run the model over each loader with the project helper, in the same order as
# before: plain train, augmented train, validation, test.
logits, logits_aug = (
    base.generate_logits(loader, model)
    for loader in (train_dataloader, train_dataloader_aug)
)

logits_eval, logits_test = (
    base.generate_logits(loader, model)
    for loader in (eval_dataloder, test_dataloader)
)

  0%|          | 0/313 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

In [12]:
# Read the raw CIFAR-100 train and test pickles through the project helper.
data_file, testing = map(
    base.unpickle,
    ("data/100/cifar-100-python/train", "data/100/cifar-100-python/test"),
)

In [13]:
# For every field of the raw training pickle except b"batch_label", keep only
# the entries selected for the train split (`data`) and for the validation
# split (`eval_data`), in index order.
# NOTE(review): b"batch_label" is presumably skipped because it is not a
# per-sample sequence — confirm against the pickle layout.
data, eval_data = (
    {
        key: [value[i] for i in split_idx]
        for key, value in data_file.items()
        if key != b"batch_label"
    }
    for split_idx in (train_idx, validation_idx)
)

In [14]:
# Store the model outputs next to the samples of each split.
# NOTE(review): assumes generate_logits returns its results in loader order, so
# they line up with the index order used to build these dictionaries — TODO confirm.
data.update({b"logits": logits, b"logits_aug": logits_aug})

eval_data[b"logits"] = logits_eval
testing[b"logits"] = logits_test

In [15]:
# Write the three logits-enriched splits out with the project helper
# (test, then train, then eval).
for out_path, payload in (
    ("data/100-logits/cifar-100-python/test", testing),
    ("data/100-logits/cifar-100-python/train", data),
    ("data/100-logits/cifar-100-python/eval", eval_data),
):
    base.pickle_up(out_path, payload)

In [16]:
# Fetch the dataset-part selector from `base`; its TRAIN member is what the
# following cell passes when loading the logits-enriched dataset.
dataset_part = base.get_dataset_part()

In [17]:
# Reload the TRAIN part of the logits-enriched data twice, first with the
# augmentation pipeline and then with the base pipeline. This rebinds the
# `train_aug` / `train` names used earlier in the notebook.
train_aug, train = (
    base.CustomCIFAR100L(root='./data/100-logits', dataset_part=dataset_part.TRAIN, transform=pipeline)
    for pipeline in (augment_transform, transform)
)

# Plain samples followed by augmented samples.
train_combo = ConcatDataset([train, train_aug])

In [18]:
# Report accuracy for the plain, augmented and combined sets, in that order.
for checked_set in (train, train_aug, train_combo):
    print(base.check_acc(checked_set))

  0%|          | 0/40000 [00:00<?, ?it/s]

Accuracy for given set is: 0.94035


  0%|          | 0/40000 [00:00<?, ?it/s]

Accuracy for given set is: 0.6315


  0%|          | 0/80000 [00:00<?, ?it/s]

Accuracy for given set is: 0.785925


In [19]:
# Filter the augmented set against the plain one with the project helper, then
# rebuild the combined dataset from the plain set plus the filtered augmented set.
# NOTE(review): judging by its name, the helper drops augmented samples whose
# predicted class differs from the un-augmented prediction — confirm in `base`.
# NOTE: `train_aug` is overwritten, so re-running this cell on its own feeds the
# already filtered set back into the helper.
train_aug = base.remove_diff_pred_class(train, train_aug)
train_combo = ConcatDataset([train, train_aug])

In [20]:
# Report accuracy again for the filtered augmented set and the rebuilt combined set.
for checked_set in (train_aug, train_combo):
    print(base.check_acc(checked_set))

  0%|          | 0/25912 [00:00<?, ?it/s]

Accuracy for given set is: 0.9609833281877123


  0%|          | 0/65912 [00:00<?, ?it/s]

Accuracy for given set is: 0.9484615851438282
