In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here are several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
# Lazily build the full path of every file found under the Kaggle input
# directory, then print the paths one per line in traversal order.
input_file_paths = (
    os.path.join(folder, name)
    for folder, _subdirs, names in os.walk('/kaggle/input')
    for name in names
)
for path in input_file_paths:
    print(path)

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [None]:
!pip install transformers

In [None]:
import torch
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import numpy as np
import torchvision
import torch.optim as optim
import torch.nn.functional as F
from tqdm import tqdm

In [None]:
# Prefer CUDA when PyTorch reports it as available; otherwise run on the CPU.
device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(device_name)
device

In [None]:
# Preprocessing shared by both splits: upsample the CIFAR-10 images to 224x224
# (the DeiT input size), convert them to tensors, and normalise with the
# ImageNet mean and std.
resize_to_model_input = transforms.Resize((224, 224))
normalize_imagenet = transforms.Normalize(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225],
)
transform = transforms.Compose([
    resize_to_model_input,
    transforms.ToTensor(),
    normalize_imagenet,
])

# Arguments common to both dataset splits and both loaders.
dataset_kwargs = dict(root='./data', download=True, transform=transform)
loader_kwargs = dict(batch_size=32, num_workers=2)

# Training split: downloaded on demand and reshuffled every epoch.
train_dataset = torchvision.datasets.CIFAR10(train=True, **dataset_kwargs)
train_loader = torch.utils.data.DataLoader(train_dataset, shuffle=True, **loader_kwargs)

# Test split: downloaded on demand and iterated in a fixed order.
test_dataset = torchvision.datasets.CIFAR10(train=False, **dataset_kwargs)
test_loader = torch.utils.data.DataLoader(test_dataset, shuffle=False, **loader_kwargs)

In [None]:
from transformers import DeiTForImageClassificationWithTeacher

# Load the pretrained checkpoint and place the resulting model on the selected device.
PRETRAINED_CHECKPOINT = 'facebook/deit-base-distilled-patch16-224'
baseline_model = DeiTForImageClassificationWithTeacher.from_pretrained(PRETRAINED_CHECKPOINT).to(device)

In [None]:
# Number of target classes (CIFAR-10).
NUM_CLASSES = 10

# Replace both classification heads with freshly initialised linear layers that
# emit NUM_CLASSES logits. A new nn.Linear is allocated on the CPU, so each head
# is moved to `device` right away to keep the whole model on a single device.
baseline_model.distillation_classifier = torch.nn.Linear(
    in_features=baseline_model.distillation_classifier.in_features,
    out_features=NUM_CLASSES,
).to(device)
baseline_model.cls_classifier = torch.nn.Linear(
    in_features=baseline_model.cls_classifier.in_features,
    out_features=NUM_CLASSES,
).to(device)

In [None]:
# Make sure the replaced classification heads live on the same device as the
# rest of the model.
baseline_model.distillation_classifier.to(device)
baseline_model.cls_classifier.to(device)

# Sanity check before training: report the device of every model parameter.
for name, param in baseline_model.named_parameters():
    print(f"{name} is on {param.device}")

# Input batches are moved to `device` inside train_one_epoch() and validate(),
# right before the forward pass. No `inputs`/`targets` exist at this point in
# the notebook, so there is nothing to move here.


In [None]:
from transformers import DeiTForImageClassificationWithTeacher
import torch.nn.functional as F

def train_one_epoch(epoch, model, train_loader, optimizer, device):
    """Run one optimisation pass over ``train_loader`` and report loss/accuracy.

    Args:
        epoch: Epoch number; only used to label the progress bar.
        model: Module called as ``model(inputs)``; the result must expose ``.logits``.
        train_loader: Sized iterable yielding ``(inputs, targets)`` batches.
        optimizer: Optimizer whose ``zero_grad()`` and ``step()`` are called once per batch.
        device: Device each batch is moved to before the forward pass.

    Returns:
        A tuple ``(mean_loss, accuracy)``: the average of the per-batch
        cross-entropy losses and the accuracy in percent. Each value is ``0.0``
        when the loader yields no batches / no samples, instead of raising
        ``ZeroDivisionError``.
    """
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0

    num_batches = len(train_loader)
    progress_bar = tqdm(enumerate(train_loader), total=num_batches)
    for batch_idx, (inputs, targets) in progress_bar:
        inputs, targets = inputs.to(device), targets.to(device)

        optimizer.zero_grad()
        logits = model(inputs).logits
        loss = F.cross_entropy(logits, targets)
        loss.backward()
        optimizer.step()

        # Accumulate running statistics for the progress bar and the return value.
        running_loss += loss.item()
        _, predicted = logits.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()

        progress_bar.set_description(f'Epoch {epoch} Loss: {running_loss/(batch_idx+1):.3f} Acc: {100.*correct/total:.3f}%')

    # Guard both divisions so an empty loader returns zeros rather than crashing.
    mean_loss = running_loss / num_batches if num_batches else 0.0
    accuracy = 100. * correct / total if total else 0.0
    return mean_loss, accuracy


In [None]:
def validate(model, test_loader, device):
    """Measure and print the classification accuracy of ``model`` on ``test_loader``.

    The model is switched to evaluation mode and gradients are disabled for the
    whole pass.

    Args:
        model: Module called as ``model(inputs)``; the result must expose ``.logits``.
        test_loader: Iterable yielding ``(inputs, targets)`` batches.
        device: Device each batch is moved to before the forward pass.

    Returns:
        Accuracy in percent. ``0.0`` when the loader yields no samples, instead
        of raising ``ZeroDivisionError``.
    """
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, targets in test_loader:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            _, predicted = outputs.logits.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

    # Guard the division so an empty loader reports 0% rather than crashing.
    acc = 100. * correct / total if total else 0.0
    print(f'Validation Accuracy: {acc:.3f}%')
    return acc

In [None]:
# Training configuration: epoch budget and an Adam optimizer over every model parameter.
# NOTE(review): lr=1e-3 may be high for fine-tuning a pretrained transformer — confirm.
max_epoch = 16
optimizer = optim.Adam(params=baseline_model.parameters(), lr=1e-3)

In [None]:
# Fine-tune for `max_epoch` epochs (numbered from 1), evaluating on the test
# loader after each epoch.
for epoch in range(1, max_epoch + 1):
    print(f"Epoch {epoch}: Normal Training")
    train_loss, train_acc = train_one_epoch(epoch, baseline_model, train_loader, optimizer, device)
    print(f"Training Loss: {train_loss}, Training Accuracy: {train_acc}")
    # validate() already prints the validation accuracy, so it is not printed again here.
    val_acc = validate(baseline_model, test_loader, device)
    