In [10]:
from datasets import load_dataset
from datasets import DatasetDict

raw_datasets = load_dataset("microsoft/cats_vs_dogs")
print("Available splits:", raw_datasets.keys())

# If a validation split is not provided, create one from the training data.
if "validation" not in raw_datasets:
    train_val = raw_datasets["train"].train_test_split(test_size=0.1, seed=42)
    train_datasets = train_val["train"]
    eval_datasets = train_val["test"]
else:
    train_datasets = raw_datasets["train"]
    eval_datasets = raw_datasets["validation"]

# Check for a test split; if not available, set test_dataset to None.
if "test" in raw_datasets:
    test_datasets = raw_datasets["test"]
else:
    test_datasets = None

print("Training samples:", len(train_datasets))
print("Validation samples:", len(eval_datasets))
if test_datasets is not None:
    print("Test samples:", len(test_datasets))

train_num = train_datasets.num_rows
valid_num = eval_datasets.num_rows


fraction = 1/10

train_sample_size = int(train_num * fraction)
valid_sample_size = int(valid_num * fraction)

train_subset = train_datasets.shuffle(seed=42).select(range(train_sample_size))
valid_subset = eval_datasets.shuffle(seed=42).select(range(valid_sample_size))


ds = DatasetDict({
    "train": train_subset,
    "validation": valid_subset
})

print("Original train size:", train_num)
print("Original validation size:", valid_num)
print("Reduced train size:", ds["train"].num_rows)
print("Reduced validation size:", ds["validation"].num_rows)


Available splits: dict_keys(['train'])
Training samples: 21069
Validation samples: 2341
Original train size: 21069
Original validation size: 2341
Reduced train size: 2106
Reduced validation size: 234


In [11]:
print(ds)

print(ds['train'][1])

DatasetDict({
    train: Dataset({
        features: ['image', 'labels'],
        num_rows: 2106
    })
    validation: Dataset({
        features: ['image', 'labels'],
        num_rows: 234
    })
})
{'image': <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=306x220 at 0x1A8FD4D2B10>, 'labels': 1}


In [12]:
ds['validation'][0]

{'image': <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=500x375>,
 'labels': 1}

In [13]:
from torchvision.transforms import Compose, RandomResizedCrop, RandomHorizontalFlip, ColorJitter, ToTensor, Normalize


transform = Compose([
    RandomResizedCrop(224),
    RandomHorizontalFlip(),
    ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    ToTensor(),
    Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])


def preprocess(example):
    # Check and handle unexpected data types
    if isinstance(example['image'], list):
        example['pixel_values'] = [transform(img) for img in example['image']]
    else:
        example['pixel_values'] = transform(example['image'])
    
    return example

ds.reset_format()  # Ensure dataset is in the original format
ds = ds.with_transform(preprocess)


In [14]:
sample = ds['train'][0]

In [15]:
import torch
from torch.utils.data import DataLoader

def collate_fn(examples):
    pixel_values = torch.stack([example['pixel_values'] for example in examples])
    labels = torch.tensor([example['labels'] for example in examples])
    return {'pixel_values': pixel_values, 'labels': labels}

train_loader = DataLoader(ds['train'], batch_size=32, shuffle=True, collate_fn=collate_fn)
val_loader = DataLoader(ds['validation'], batch_size=32, shuffle=False, collate_fn=collate_fn)


In [16]:
from transformers import AutoImageProcessor, ViTForImageClassification


model_name = "google/vit-base-patch16-224-in21k"
model = ViTForImageClassification.from_pretrained(
    model_name,
    num_labels=2,  
)

model = model.to("cuda" if torch.cuda.is_available() else "cpu")


Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [18]:
from torch.cuda.amp import GradScaler, autocast
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.optim import AdamW
from transformers import AutoModelForImageClassification, AutoImageProcessor
from torch.utils.data import DataLoader


device = "cuda" if torch.cuda.is_available() else "cpu"
scaler = GradScaler()


model = AutoModelForImageClassification.from_pretrained("google/vit-base-patch16-224-in21k").to(device)
optimizer = AdamW(model.parameters(), lr=0.0002)
criterion = torch.nn.CrossEntropyLoss()
scheduler = CosineAnnealingLR(optimizer, T_max=10)

epochs = 5
train_batch_size = 128
eval_batch_size = 128
log_every = 10  


results = {"epoch": [], "step": [], "train_loss": [], "val_loss": [], "accuracy": []}

for epoch in range(1, epochs + 1):
    print(f"Epoch {epoch}/{epochs}")
    model.train()
    total_train_loss = 0
    total_steps = 0

    for batch_idx, batch in enumerate(train_loader):
        inputs = batch["pixel_values"].to(device)
        labels = batch["labels"].to(device)

        optimizer.zero_grad()

        with autocast():
            outputs = model(inputs)
            loss = criterion(outputs.logits, labels)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        total_train_loss += loss.item()
        total_steps += 1

        
        if batch_idx % log_every == 0:
            print(f"Batch {batch_idx}/{len(train_loader)}, Loss: {loss.item():.4f}")

    avg_train_loss = total_train_loss / total_steps
    print(f"Epoch {epoch} Completed. Average Training Loss: {avg_train_loss:.4f}")


    model.eval()
    total_val_loss = 0
    correct = 0
    total = 0

    with torch.no_grad():
        for batch in val_loader:
            inputs = batch["pixel_values"].to(device)
            labels = batch["labels"].to(device)

            outputs = model(inputs)
            loss = criterion(outputs.logits, labels)
            total_val_loss += loss.item()

            _, preds = torch.max(outputs.logits, 1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)

    avg_val_loss = total_val_loss / len(val_loader)
    accuracy = correct / total
    print(f"Validation Loss: {avg_val_loss:.4f}, Accuracy: {accuracy:.4f}")

    
    results["epoch"].append(epoch)
    results["step"].append(total_steps)
    results["train_loss"].append(avg_train_loss)
    results["val_loss"].append(avg_val_loss)
    results["accuracy"].append(accuracy)


print("\nTraining Summary:")
print(f"{'Epoch':<10}{'Training Loss':<15}{'Validation Loss':<20}{'Accuracy':<10}")
for i in range(epochs):
    print(f"{results['epoch'][i]:<10}{results['train_loss'][i]:<15.4f}{results['val_loss'][i]:<20.4f}{results['accuracy'][i]:<10.4f}")


  scaler = GradScaler()
Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Epoch 1/5


  with autocast():


Batch 0/66, Loss: 0.6902
Batch 10/66, Loss: 0.2080
Batch 20/66, Loss: 0.2022
Batch 30/66, Loss: 0.1745
Batch 40/66, Loss: 0.1785
Batch 50/66, Loss: 0.1514
Batch 60/66, Loss: 0.2556
Epoch 1 Completed. Average Training Loss: 0.2064
Validation Loss: 0.0700, Accuracy: 0.9701
Epoch 2/5
Batch 0/66, Loss: 0.0923
Batch 10/66, Loss: 0.1762
Batch 20/66, Loss: 0.2810
Batch 30/66, Loss: 0.0957
Batch 40/66, Loss: 0.2374
Batch 50/66, Loss: 0.0403
Batch 60/66, Loss: 0.0876
Epoch 2 Completed. Average Training Loss: 0.1287
Validation Loss: 0.1142, Accuracy: 0.9573
Epoch 3/5
Batch 0/66, Loss: 0.0756
Batch 10/66, Loss: 0.1222
Batch 20/66, Loss: 0.1085
Batch 30/66, Loss: 0.0511
Batch 40/66, Loss: 0.2304
Batch 50/66, Loss: 0.2138
Batch 60/66, Loss: 0.0805
Epoch 3 Completed. Average Training Loss: 0.1163
Validation Loss: 0.1244, Accuracy: 0.9444
Epoch 4/5
Batch 0/66, Loss: 0.1408
Batch 10/66, Loss: 0.1315
Batch 20/66, Loss: 0.2643
Batch 30/66, Loss: 0.1977
Batch 40/66, Loss: 0.1331
Batch 50/66, Loss: 0.0441

In [19]:
print(batch.keys())

dict_keys(['pixel_values', 'labels'])


In [20]:
model.eval()
correct = 0
total = 0
with torch.no_grad():
    for batch in val_loader:
        
        images = batch["pixel_values"].to(device)
        labels = batch["labels"].to(device)
        
        outputs = model(images)
        logits = outputs.logits  
        
        _, predicted = torch.max(logits, 1)
        
        
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print(f"Validation Accuracy: {100 * correct / total:.2f}%")

Validation Accuracy: 94.87%
