In [76]:
from __future__ import print_function
from __future__ import division
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import torchvision
from datasets import load_dataset
from torchvision import datasets, models, transforms
from functools import partial
import matplotlib.pyplot as plt
import time
import os
import copy

In [77]:
# Log the library versions so results can be tied to a specific environment.
for version_line in (
    ("PyTorch Version: ", torch.__version__),
    ("Torchvision Version: ", torchvision.__version__),
):
    print(*version_line)

PyTorch Version:  2.9.1+cu128
Torchvision Version:  0.24.1+cu128


In [78]:
# Train on the first CUDA GPU when one is visible; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
device

device(type='cuda', index=0)

In [79]:
# Square side length (pixels) of every image fed to the network.
input_size = 224

# Standard ImageNet per-channel mean/std (presumably matching the pretrained
# torchvision backbone used later - confirm). Defined once so the train and val
# pipelines are guaranteed to normalize identically.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

data_transforms = {
    # Training: random scale/aspect-ratio crop plus horizontal flip for augmentation.
    'train': transforms.Compose([
        transforms.RandomResizedCrop(input_size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
    ]),
    # Validation: deterministic resize followed by a center crop to input_size.
    'val': transforms.Compose([
        transforms.Resize(input_size),
        transforms.CenterCrop(input_size),
        transforms.ToTensor(),
        transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
    ]),
}

In [97]:
def transforms_fn(example_batch, split):
    """Attach model-ready image tensors to a batch of dataset rows.

    Intended for ``Dataset.set_transform``; bind ``split`` by keyword with
    ``functools.partial`` because the batch is passed as the first positional
    argument. Each PIL image in ``example_batch['image']`` is converted to RGB and
    run through the ``data_transforms[split]`` pipeline.

    Args:
        example_batch: dict-of-lists batch containing an ``'image'`` list of PIL images.
        split: key into ``data_transforms`` (``'train'`` or ``'val'``).

    Returns:
        The same dict, mutated in place, with a new ``'pixel_values'`` list holding
        one transformed tensor per image. The ``'image'`` entry is left in place.

    Raises:
        KeyError: if ``split`` is not a key of ``data_transforms`` or the batch has
            no ``'image'`` entry.
    """
    # Resolve the pipeline once per batch instead of once per image.
    transform = data_transforms[split]
    # .convert("RGB") makes every image 3-channel, matching the 3-value Normalize.
    example_batch["pixel_values"] = [
        transform(pil_img.convert("RGB")) for pil_img in example_batch["image"]
    ]
    return example_batch

In [81]:
# Load the COCO-background image-classification dataset from the Hugging Face Hub.
# Later cells use its "train" and "validation" splits and its "image"/"label" columns.
dataset = load_dataset("HichTala/coco-background")

In [82]:
# Attach lazy, on-access preprocessing to each split.
# `split` MUST be bound by keyword: set_transform calls the function with the batch
# as its first positional argument, so partial(transforms_fn, "val") would bind
# "val" to `example_batch` and crash the first time validation data is read.
dataset["train"].set_transform(partial(transforms_fn, split="train"))
dataset["validation"].set_transform(partial(transforms_fn, split="val"))

In [83]:
def collate_fn(examples):
    """Merge a list of transformed dataset rows into a single batch.

    Args:
        examples: sequence of dicts, each with a ``"pixel_values"`` tensor and an
            integer ``"label"``. All ``"pixel_values"`` tensors must share one shape.

    Returns:
        dict with ``"pixel_values"`` (per-row tensors stacked along a new leading
        batch dimension) and ``"labels"`` (a 1-D tensor of the row labels).
    """
    images, targets = [], []
    for row in examples:
        images.append(row["pixel_values"])
        targets.append(row["label"])
    batch_images = torch.stack(images)
    batch_targets = torch.tensor(targets)
    return {"pixel_values": batch_images, "labels": batch_targets}

In [84]:
# One DataLoader per split. Only the training split is shuffled: validation order
# does not affect the metrics, and a fixed order keeps evaluation reproducible.
BATCH_SIZE = 8
NUM_WORKERS = 4  # worker processes that decode/transform images in parallel

dataloaders_dict = {
    split_name: torch.utils.data.DataLoader(
        dataset[split_name],
        batch_size=BATCH_SIZE,
        shuffle=(split_name == 'train'),
        num_workers=NUM_WORKERS,
        collate_fn=collate_fn,
    )
    for split_name in ['train', 'validation']
}

In [85]:
# Display the training DataLoader object (quick check that it was constructed).
dataloaders_dict['train']

<torch.utils.data.dataloader.DataLoader at 0x7bbc1ce6d820>

In [86]:
# Inspect the raw columns of the training split: transforms_fn reads "image" and
# collate_fn reads "label".
dataset["train"].column_names

['image', 'label']

In [87]:
# Alias the training loader built above rather than constructing a second
# DataLoader from a copy-pasted configuration that could silently drift.
dataloader = dataloaders_dict['train']

In [88]:
# Smoke test: pull one collated training batch end to end (transform + collate).
# NOTE(review): the saved output of this cell contains printed
# {'image': ..., 'label': ...} dicts that the current transforms_fn does not print;
# it appears to come from an earlier version of that function (its cell shows
# In [97], i.e. it was re-run after this one). Re-run to refresh the output.
next(iter(dataloaders_dict['train']))

{'image': [<PIL.PngImagePlugin.PngImageFile image mode=RGB size=4x4 at 0x7BBC1CE6D6A0>, <PIL.PngImagePlugin.PngImageFile image mode=RGB size=4x9 at 0x7BBC1CE6D6D0>, <PIL.PngImagePlugin.PngImageFile image mode=RGB size=23x19 at 0x7BBC1CE6CCE0>, <PIL.PngImagePlugin.PngImageFile image mode=RGB size=20x55 at 0x7BBC1CE6C8F0>, <PIL.PngImagePlugin.PngImageFile image mode=RGB size=297x190 at 0x7BBC1CE6D1F0>, <PIL.PngImagePlugin.PngImageFile image mode=RGB size=107x67 at 0x7BBC1CE6CC20>, <PIL.PngImagePlugin.PngImageFile image mode=RGB size=425x186 at 0x7BBC1CE6C890>, <PIL.PngImagePlugin.PngImageFile image mode=RGB size=76x155 at 0x7BBC1CE6CF50>], 'label': [63, 50, 6, 50, 19, 13, 2, 50]}
{'image': [<PIL.PngImagePlugin.PngImageFile image mode=RGB size=171x640 at 0x7BBC1CDD7140>, <PIL.PngImagePlugin.PngImageFile image mode=RGB size=20x15 at 0x7BBC1CDD4BC0>, <PIL.PngImagePlugin.PngImageFile image mode=RGB size=9x16 at 0x7BBC1CDD42C0>, <PIL.PngImagePlugin.PngImageFile image mode=RGB size=40x72 at 0x

{'pixel_values': tensor([[[[-1.2788, -1.2788, -1.2788,  ..., -1.3302, -1.3302, -1.3302],
           [-1.2788, -1.2788, -1.2788,  ..., -1.3302, -1.3302, -1.3302],
           [-1.2788, -1.2788, -1.2788,  ..., -1.3302, -1.3302, -1.3302],
           ...,
           [-1.3473, -1.3473, -1.3473,  ..., -0.6281, -0.6281, -0.6281],
           [-1.3473, -1.3473, -1.3473,  ..., -0.6281, -0.6281, -0.6281],
           [-1.3473, -1.3473, -1.3473,  ..., -0.6281, -0.6281, -0.6281]],
 
          [[-1.1078, -1.1078, -1.1078,  ..., -1.1954, -1.1954, -1.1954],
           [-1.1078, -1.1078, -1.1078,  ..., -1.1954, -1.1954, -1.1954],
           [-1.1078, -1.1078, -1.1078,  ..., -1.1954, -1.1954, -1.1954],
           ...,
           [-1.3179, -1.3179, -1.3179,  ..., -0.4951, -0.4951, -0.4951],
           [-1.3179, -1.3179, -1.3179,  ..., -0.4951, -0.4951, -0.4951],
           [-1.3179, -1.3179, -1.3179,  ..., -0.4951, -0.4951, -0.4951]],
 
          [[-1.3513, -1.3513, -1.3513,  ..., -1.1944, -1.1944, -1.1944

In [96]:
# Grab a single training batch for inspection. next(iter(...)) states the intent
# directly and raises StopIteration on an empty loader, instead of silently
# leaving `inputs`/`labels` undefined as the for/break form did.
sample = next(iter(dataloaders_dict['train']))
inputs = sample['pixel_values']
labels = sample['labels']

{'image': [<PIL.PngImagePlugin.PngImageFile image mode=RGB size=37x16 at 0x7BBB19186840>, <PIL.PngImagePlugin.PngImageFile image mode=RGB size=40x92 at 0x7BBB19186D80>, <PIL.PngImagePlugin.PngImageFile image mode=RGB size=63x44 at 0x7BBB191842F0>, <PIL.PngImagePlugin.PngImageFile image mode=RGB size=187x67 at 0x7BBB191850A0>, <PIL.PngImagePlugin.PngImageFile image mode=RGB size=23x34 at 0x7BBB19185E20>, <PIL.PngImagePlugin.PngImageFile image mode=RGB size=36x37 at 0x7BBB19185190>, <PIL.PngImagePlugin.PngImageFile image mode=RGB size=23x56 at 0x7BBB19184F80>, <PIL.PngImagePlugin.PngImageFile image mode=RGB size=18x43 at 0x7BBB19185370>], 'label': [59, 12, 9, 66, 50, 67, 50, 37]}
{'image': [<PIL.PngImagePlugin.PngImageFile image mode=RGB size=24x36 at 0x7BBB19186840>, <PIL.PngImagePlugin.PngImageFile image mode=RGB size=19x23 at 0x7BBB19186D80>, <PIL.PngImagePlugin.PngImageFile image mode=RGB size=9x28 at 0x7BBB19186D50>, <PIL.PngImagePlugin.PngImageFile image mode=RGB size=43x62 at 0x7B

In [93]:
# NOTE(review): duplicates the cell above that sets `inputs`/`labels`. Because the
# train loader shuffles, this rebinds `sample` to an almost certainly different
# batch, so `sample` no longer matches `inputs`/`labels`. Consider deleting it.
sample = next(iter(dataloaders_dict['train']))

{'image': [<PIL.PngImagePlugin.PngImageFile image mode=RGB size=17x32 at 0x7BBB19186390>, <PIL.PngImagePlugin.PngImageFile image mode=RGB size=6x11 at 0x7BBB19185F10>, <PIL.PngImagePlugin.PngImageFile image mode=RGB size=13x10 at 0x7BBB19185670>, <PIL.PngImagePlugin.PngImageFile image mode=RGB size=28x48 at 0x7BBB19185340>, <PIL.PngImagePlugin.PngImageFile image mode=RGB size=640x141 at 0x7BBB191840E0>, <PIL.PngImagePlugin.PngImageFile image mode=RGB size=27x29 at 0x7BBB19186270>, <PIL.PngImagePlugin.PngImageFile image mode=RGB size=25x85 at 0x7BBB19185430>, <PIL.PngImagePlugin.PngImageFile image mode=RGB size=34x30 at 0x7BBB19185190>], 'label': [26, 73, 73, 50, 2, 16, 50, 26]}{'image': [<PIL.PngImagePlugin.PngImageFile image mode=RGB size=30x44 at 0x7BBB191867B0>, <PIL.PngImagePlugin.PngImageFile image mode=RGB size=29x47 at 0x7BBB191868A0>, <PIL.PngImagePlugin.PngImageFile image mode=RGB size=9x27 at 0x7BBB19186D80>, <PIL.PngImagePlugin.PngImageFile image mode=RGB size=44x27 at 0x7BB

In [98]:
# Sanity-check the batch before displaying its shape: 4-D (batch, channels, H, W)
# with 3 channels at input_size x input_size, and exactly one label per image.
assert inputs.ndim == 4 and inputs.shape[1:] == (3, input_size, input_size), inputs.shape
assert labels.shape == (inputs.shape[0],), labels.shape
inputs.shape

torch.Size([8, 3, 224, 224])