In [None]:
from typing import Callable, Optional

import torch
from torch.utils.data import Dataset
from torchvision import datasets, transforms

DATASETS_DIR = "data"

class GPUDataset(Dataset):
    """Materialise a whole dataset as two pre-transformed tensors on ``device``.

    Every sample of ``dataset`` is passed through ``transform`` once, up front, and
    written into a single preallocated tensor. ``__getitem__`` is therefore just an
    index into tensors that already live on ``device``.

    Args:
        dataset: Indexable, sized dataset yielding ``(sample, target)`` pairs, where
            ``target`` is convertible to an integer class index.
        transform: Callable applied to each raw sample. It must return a tensor of
            the same shape for every sample. If ``None``, samples are converted with
            ``torch.as_tensor``, so they must already be tensor-like (not PIL images).
        device: Device the tensors are stored on (``str`` or ``torch.device``).
    """

    def __init__(
        self,
        dataset: Dataset,
        transform: Optional[Callable[..., torch.Tensor]],
        device: str,
    ):
        self.device = device
        self.transform = transform

        # ``transform`` is Optional: fall back to a plain tensor conversion instead
        # of failing with "'NoneType' object is not callable".
        to_tensor = transform if transform is not None else torch.as_tensor

        num_samples = len(dataset)
        if num_samples == 0:
            # No sample to infer a shape from; expose empty, correctly typed tensors.
            self.data = torch.empty(0, device=device)
            self.targets = torch.empty(0, dtype=torch.long, device=device)
            return

        sample_shape = to_tensor(dataset[0][0]).shape
        # Preallocate once (default float dtype, as before) and fill row by row.
        # Assigning into a device tensor copies across devices automatically.
        self.data = torch.empty((num_samples, *sample_shape), device=device)
        labels = []
        for i, (sample, target) in enumerate(dataset):
            self.data[i] = to_tensor(sample)
            labels.append(int(target))
        # One transfer for all labels instead of one tiny tensor per sample.
        self.targets = torch.tensor(labels, dtype=torch.long, device=device)

    def __len__(self) -> int:
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(sample, target)``; both tensors live on ``self.device``."""
        return self.data[idx], self.targets[idx]


def get_dataset(
    dataset_name: str = "FashionMNIST", transform: Optional[Callable] = None
) -> tuple[Dataset, Dataset]:
    """Download (if needed) and return the ``(train, test)`` splits of a torchvision dataset.

    Args:
        dataset_name: ``"FashionMNIST"`` or ``"CIFAR100"``.
        transform: Optional per-sample transform forwarded to both splits. With
            ``None`` the datasets yield raw (untransformed) samples.

    Returns:
        ``(train_dataset, test_dataset)``, stored under ``DATASETS_DIR``.

    Raises:
        ValueError: if ``dataset_name`` is not supported. Previously this fell
            through to an ``UnboundLocalError`` at the ``return``.
    """
    if dataset_name == "FashionMNIST":
        dataset_cls = datasets.FashionMNIST
    elif dataset_name == "CIFAR100":
        dataset_cls = datasets.CIFAR100
    else:
        raise ValueError(f"Unknown dataset name: {dataset_name}")

    # Both splits get the same transform, for every dataset. FashionMNIST used to
    # drop it silently.
    train_dataset = dataset_cls(
        root=DATASETS_DIR, train=True, download=True, transform=transform
    )
    test_dataset = dataset_cls(
        root=DATASETS_DIR, train=False, download=True, transform=transform
    )
    return train_dataset, test_dataset


In [None]:
import torch
from torchvision import transforms
from torchvision.models import MobileNet_V3_Small_Weights
from torch.utils.data import DataLoader, Dataset, Subset, random_split
import numpy as np
from typing import Optional
import os


def get_dataloaders(
    dataset_name: str = "FashionMNIST",
    batch_size: int = 32,
    val_split: float = 0.2,
    seed: int = 42,
    num_workers: int = 0,
    transform: Optional[Callable] = None,
    device: str = "cuda",
) -> tuple[Dataset, DataLoader, DataLoader, DataLoader]:
    """Build train / validation / test dataloaders for ``dataset_name``.

    The training set is split into train/validation parts with ``val_split``.
    ``torch.manual_seed(seed)`` is called first, so the split is reproducible.

    - FashionMNIST: every split is eagerly transformed and cached on ``device`` via
      ``GPUDataset``. Default transform: ``ToTensor`` + ``Normalize((0.5,), (0.5,))``.
    - CIFAR100: samples are transformed lazily per item. Default transform: the
      MobileNetV3-Small preprocessing.

    A caller-supplied ``transform`` replaces the default in both cases.

    NOTE(review): with a CUDA ``device``, keep ``num_workers=0`` for FashionMNIST.
    Its tensors already live on the GPU, and CUDA tensors generally can't be
    used from DataLoader worker processes.

    Args:
        dataset_name: ``"FashionMNIST"`` or ``"CIFAR100"``.
        batch_size: Samples per batch for all three loaders.
        val_split: Fraction of the training set held out for validation.
        seed: Seed for the train/validation split.
        num_workers: DataLoader worker processes.
        transform: Optional per-sample transform overriding the default.
        device: Where ``GPUDataset`` stores FashionMNIST tensors.

    Returns:
        ``(train_dataset, train_dataloader, val_dataloader, test_dataloader)``.

    Raises:
        ValueError: if ``dataset_name`` is not supported.
    """
    if dataset_name not in ("FashionMNIST", "CIFAR100"):
        raise ValueError(f"Unknown dataset name: {dataset_name}")

    torch.manual_seed(seed)

    if dataset_name == "FashionMNIST":
        if transform is None:
            transform = transforms.Compose(
                [
                    transforms.ToTensor(),
                    transforms.Normalize((0.5,), (0.5,)),
                ]
            )
        # Raw datasets here: the transform is applied once, below, by GPUDataset.
        train_dataset, test_dataset = get_dataset(dataset_name=dataset_name)
    else:  # CIFAR100
        if transform is None:
            transform = MobileNet_V3_Small_Weights.DEFAULT.transforms()
        # Applied lazily per item. Without it the loaders would yield PIL images,
        # which the default collate function cannot batch.
        train_dataset, test_dataset = get_dataset(
            dataset_name=dataset_name, transform=transform
        )

    val_size = int(val_split * len(train_dataset))
    train_size = len(train_dataset) - val_size
    train_dataset, val_dataset = random_split(train_dataset, [train_size, val_size])

    if dataset_name == "FashionMNIST":
        train_dataset = GPUDataset(train_dataset, transform, device)
        val_dataset = GPUDataset(val_dataset, transform, device)
        test_dataset = GPUDataset(test_dataset, transform, device)

    # Reshuffle training data every epoch; evaluation order is irrelevant.
    train_dataloader = DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers
    )
    val_dataloader = DataLoader(
        val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers
    )
    test_dataloader = DataLoader(
        test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers
    )

    return train_dataset, train_dataloader, val_dataloader, test_dataloader


def get_subset_loader(
    dataset: Dataset, mask: np.ndarray, batch_size: int = 32, num_workers: int = 0
) -> DataLoader:
    """Return a shuffling DataLoader over the samples of ``dataset`` selected by ``mask``.

    Args:
        dataset: Sized, indexable dataset.
        mask: Boolean array with one entry per sample of ``dataset``. ``True`` keeps
            the sample.
        batch_size: Samples per batch.
        num_workers: DataLoader worker processes.

    Raises:
        AssertionError: if ``mask`` is not boolean or its length differs from the
            dataset's.
    """
    assert mask.dtype == bool, "Mask must be a boolean array."
    assert len(mask) == len(dataset), "Mask and dataset must have the same length."

    chosen = np.nonzero(mask)[0]
    return DataLoader(
        Subset(dataset, chosen),
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
    )



# Shared experiment configuration; later cells reuse these names.
num_workers = 0
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dataset_name = "FashionMNIST"

train_dataset, train_dataloader, val_dataloader, test_dataloader = get_dataloaders(
    dataset_name=dataset_name, num_workers=num_workers, device=device
)

# Smoke test: a loader restricted to the first 500 training samples.
total_samples = len(train_dataset)
mask = np.arange(total_samples) < 500
subset_loader = get_subset_loader(train_dataset, mask, num_workers=num_workers)

print(f"Train dataloader: {len(train_dataloader)}")
print(f"Validation dataloader: {len(val_dataloader)}")
print(f"Test dataloader: {len(test_dataloader)}")
print(f"Subset dataloader: {len(subset_loader)}")

if dataset_name == "FashionMNIST":
    # subset_loader.dataset is a Subset; its .dataset is the GPUDataset holding the tensors.
    cached_split = subset_loader.dataset.dataset
    print(cached_split.data.device)
    print(cached_split.targets.device)


Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to data/FashionMNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 26.4M/26.4M [00:01<00:00, 14.3MB/s]


Extracting data/FashionMNIST/raw/train-images-idx3-ubyte.gz to data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to data/FashionMNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 29.5k/29.5k [00:00<00:00, 114kB/s]


Extracting data/FashionMNIST/raw/train-labels-idx1-ubyte.gz to data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 4.42M/4.42M [00:01<00:00, 4.24MB/s]


Extracting data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz to data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 5.15k/5.15k [00:00<00:00, 4.50MB/s]


Extracting data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz to data/FashionMNIST/raw

Train dataloader: 1500
Validation dataloader: 375
Test dataloader: 313
Subset dataloader: 16
cpu
cpu


In [None]:
from torch import optim
from tqdm import tqdm
import torch.nn as nn

def train_model(
    model: nn.Module,
    train_loader: DataLoader,
    num_epochs: int = 2,
    learning_rate: float = 1e-3,
) -> list[float]:
    """Train ``model`` in place with Adam and cross-entropy loss.

    Args:
        model: Classifier returning logits; must have trainable parameters.
        train_loader: Yields ``(inputs, integer_labels)`` batches. Each batch is moved
            to the device of the model's parameters, so CPU-resident loaders also
            work with a GPU model. Batches already on that device are unaffected.
        num_epochs: Number of full passes over ``train_loader``.
        learning_rate: Adam learning rate.

    Returns:
        The mean per-batch training loss of each epoch, or ``nan`` for an epoch with
        no batches. The function used to return ``None``; callers that ignore the
        result are unaffected.
    """
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    # Adam has already rejected a parameter-less model, so this ``next`` is safe.
    device = next(model.parameters()).device

    epoch_losses: list[float] = []
    for _ in tqdm(range(num_epochs)):
        model.train()
        running_loss = 0.0
        num_batches = 0
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)

            optimizer.zero_grad()
            loss = criterion(model(images), labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            num_batches += 1
        epoch_losses.append(running_loss / num_batches if num_batches else float("nan"))
    return epoch_losses


def validate_model(model: nn.Module, val_loader: DataLoader):
    """Evaluate ``model`` on ``val_loader`` with cross-entropy loss and top-1 accuracy.

    The model is switched to eval mode and run under ``torch.inference_mode()``.
    Batches are moved to the device of the model's parameters, if it has any.

    Returns:
        ``(val_loss, accuracy)``:

        - ``val_loss`` is the *sum* of the per-batch mean losses. It scales with the
          number of batches; this is kept so results stay comparable with earlier runs.
        - ``accuracy`` is the fraction of correctly classified samples, or ``nan``
          when the loader is empty (previously a ``ZeroDivisionError``).
    """
    model.eval()
    criterion = nn.CrossEntropyLoss()
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else None

    val_loss = 0.0
    correct = 0
    total = 0
    with torch.inference_mode():
        for images, labels in val_loader:
            if device is not None:
                images, labels = images.to(device), labels.to(device)
            output = model(images)
            val_loss += criterion(output, labels).item()
            predicted = output.argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    accuracy = correct / total if total else float("nan")
    return val_loss, accuracy

In [None]:
import torch.nn.functional as F

class SimpleFCN(nn.Module):
    """Fully connected classifier: ``in_features -> 128 -> 64 -> num_classes`` with ReLU.

    The defaults match 28x28 single-channel inputs with 10 classes (FashionMNIST).
    The layer names ``fc1``..``fc3`` are unchanged, so existing state dicts still load.
    """

    def __init__(self, in_features: int = 28 * 28, num_classes: int = 10):
        super().__init__()
        self.fc1 = nn.Linear(in_features, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Flatten all non-batch dimensions of ``x`` and return unnormalised logits."""
        # Unlike ``view``, ``flatten`` also accepts non-contiguous inputs and empty batches.
        x = torch.flatten(x, start_dim=1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)


model = SimpleFCN()

In [None]:
# Learning curve: validation loss / accuracy vs. the percentage of training data used.
# Each percentage is averaged over N_REPEATS random subsets and model initialisations.
PERCENTAGES = range(5, 101, 5)
N_REPEATS = 10
SWEEP_SEED = 0

# Seed both RNGs: numpy picks the subsets; torch drives weight init and shuffling.
sweep_rng = np.random.default_rng(SWEEP_SEED)
torch.manual_seed(SWEEP_SEED)

learning_curve = []
for i in PERCENTAGES:
    subsample = int(len(train_dataset) * i / 100)  # same for every repeat
    losses, accuracies = [], []
    for _ in range(N_REPEATS):
        mask = np.zeros(len(train_dataset), dtype=bool)
        mask[sweep_rng.choice(len(train_dataset), subsample, replace=False)] = True

        subset_loader = get_subset_loader(train_dataset, mask, num_workers=num_workers)
        model = SimpleFCN().to(device)
        train_model(model, subset_loader)
        val_loss, val_acc = validate_model(model, val_dataloader)
        losses.append(val_loss)
        accuracies.append(val_acc)

    mean_loss = sum(losses) / len(losses)
    mean_acc = sum(accuracies) / len(accuracies)
    learning_curve.append({"percentage": i, "loss": mean_loss, "acc": mean_acc})
    print(f"Percentage: {i} , loss={mean_loss} acc={mean_acc}")



100%|██████████| 2/2 [00:00<00:00,  4.56it/s]
100%|██████████| 2/2 [00:00<00:00,  5.88it/s]
100%|██████████| 2/2 [00:00<00:00,  5.75it/s]
100%|██████████| 2/2 [00:00<00:00,  5.79it/s]
100%|██████████| 2/2 [00:00<00:00,  5.93it/s]
100%|██████████| 2/2 [00:00<00:00,  5.42it/s]
100%|██████████| 2/2 [00:00<00:00,  5.71it/s]
100%|██████████| 2/2 [00:00<00:00,  5.78it/s]
100%|██████████| 2/2 [00:00<00:00,  5.46it/s]
100%|██████████| 2/2 [00:00<00:00,  5.90it/s]


Percentage: 5 , loss=234.2657291650772 acc=0.7690333333333332


100%|██████████| 2/2 [00:00<00:00,  2.02it/s]
100%|██████████| 2/2 [00:01<00:00,  1.70it/s]
100%|██████████| 2/2 [00:00<00:00,  2.71it/s]
100%|██████████| 2/2 [00:00<00:00,  3.01it/s]
100%|██████████| 2/2 [00:00<00:00,  2.83it/s]
100%|██████████| 2/2 [00:00<00:00,  3.02it/s]
100%|██████████| 2/2 [00:00<00:00,  3.06it/s]
100%|██████████| 2/2 [00:00<00:00,  3.02it/s]
100%|██████████| 2/2 [00:00<00:00,  2.99it/s]
100%|██████████| 2/2 [00:00<00:00,  3.03it/s]


Percentage: 10 , loss=214.3893195167184 acc=0.7912416666666667


100%|██████████| 2/2 [00:01<00:00,  1.94it/s]
100%|██████████| 2/2 [00:01<00:00,  1.95it/s]
100%|██████████| 2/2 [00:01<00:00,  1.41it/s]
100%|██████████| 2/2 [00:01<00:00,  1.81it/s]
100%|██████████| 2/2 [00:00<00:00,  2.00it/s]
100%|██████████| 2/2 [00:00<00:00,  2.06it/s]
100%|██████████| 2/2 [00:01<00:00,  1.99it/s]
100%|██████████| 2/2 [00:00<00:00,  2.00it/s]
100%|██████████| 2/2 [00:00<00:00,  2.03it/s]
100%|██████████| 2/2 [00:00<00:00,  2.03it/s]


Percentage: 15 , loss=190.65281866043807 acc=0.8166083333333335


100%|██████████| 2/2 [00:01<00:00,  1.41it/s]
100%|██████████| 2/2 [00:01<00:00,  1.04it/s]
100%|██████████| 2/2 [00:01<00:00,  1.49it/s]
100%|██████████| 2/2 [00:01<00:00,  1.41it/s]
100%|██████████| 2/2 [00:01<00:00,  1.50it/s]
100%|██████████| 2/2 [00:01<00:00,  1.50it/s]
100%|██████████| 2/2 [00:01<00:00,  1.53it/s]
100%|██████████| 2/2 [00:01<00:00,  1.49it/s]
100%|██████████| 2/2 [00:01<00:00,  1.06it/s]
100%|██████████| 2/2 [00:01<00:00,  1.34it/s]


Percentage: 20 , loss=186.9966082006693 acc=0.8178000000000001


100%|██████████| 2/2 [00:01<00:00,  1.17it/s]
100%|██████████| 2/2 [00:01<00:00,  1.17it/s]
100%|██████████| 2/2 [00:01<00:00,  1.10it/s]
100%|██████████| 2/2 [00:01<00:00,  1.17it/s]
100%|██████████| 2/2 [00:02<00:00,  1.12s/it]
100%|██████████| 2/2 [00:01<00:00,  1.12it/s]
100%|██████████| 2/2 [00:01<00:00,  1.14it/s]
100%|██████████| 2/2 [00:01<00:00,  1.16it/s]
100%|██████████| 2/2 [00:01<00:00,  1.21it/s]
100%|██████████| 2/2 [00:01<00:00,  1.02it/s]


Percentage: 25 , loss=176.4296816572547 acc=0.8295833333333332


100%|██████████| 2/2 [00:02<00:00,  1.44s/it]
100%|██████████| 2/2 [00:02<00:00,  1.03s/it]
100%|██████████| 2/2 [00:02<00:00,  1.01s/it]
100%|██████████| 2/2 [00:02<00:00,  1.01s/it]
100%|██████████| 2/2 [00:02<00:00,  1.02s/it]
100%|██████████| 2/2 [00:02<00:00,  1.42s/it]
100%|██████████| 2/2 [00:02<00:00,  1.15s/it]
100%|██████████| 2/2 [00:02<00:00,  1.03s/it]
100%|██████████| 2/2 [00:02<00:00,  1.01s/it]
100%|██████████| 2/2 [00:02<00:00,  1.06s/it]


Percentage: 30 , loss=174.43146554231643 acc=0.8311249999999999


100%|██████████| 2/2 [00:03<00:00,  1.57s/it]
100%|██████████| 2/2 [00:02<00:00,  1.20s/it]
100%|██████████| 2/2 [00:02<00:00,  1.19s/it]
100%|██████████| 2/2 [00:02<00:00,  1.16s/it]
100%|██████████| 2/2 [00:03<00:00,  1.56s/it]
100%|██████████| 2/2 [00:02<00:00,  1.27s/it]
100%|██████████| 2/2 [00:02<00:00,  1.23s/it]
100%|██████████| 2/2 [00:02<00:00,  1.18s/it]
100%|██████████| 2/2 [00:02<00:00,  1.31s/it]
100%|██████████| 2/2 [00:02<00:00,  1.44s/it]


Percentage: 35 , loss=171.3694272786379 acc=0.8361666666666666


100%|██████████| 2/2 [00:02<00:00,  1.40s/it]
100%|██████████| 2/2 [00:02<00:00,  1.42s/it]
100%|██████████| 2/2 [00:03<00:00,  1.52s/it]
100%|██████████| 2/2 [00:03<00:00,  1.61s/it]
100%|██████████| 2/2 [00:02<00:00,  1.37s/it]
100%|██████████| 2/2 [00:02<00:00,  1.40s/it]
100%|██████████| 2/2 [00:03<00:00,  1.59s/it]
100%|██████████| 2/2 [00:03<00:00,  1.58s/it]
100%|██████████| 2/2 [00:02<00:00,  1.43s/it]
100%|██████████| 2/2 [00:02<00:00,  1.34s/it]


Percentage: 40 , loss=164.36263493523 acc=0.8428833333333333


100%|██████████| 2/2 [00:04<00:00,  2.01s/it]
100%|██████████| 2/2 [00:03<00:00,  1.55s/it]
100%|██████████| 2/2 [00:03<00:00,  1.57s/it]
100%|██████████| 2/2 [00:03<00:00,  1.66s/it]
100%|██████████| 2/2 [00:03<00:00,  1.85s/it]
100%|██████████| 2/2 [00:03<00:00,  1.56s/it]
100%|██████████| 2/2 [00:03<00:00,  1.57s/it]
100%|██████████| 2/2 [00:03<00:00,  1.98s/it]
100%|██████████| 2/2 [00:03<00:00,  1.55s/it]
100%|██████████| 2/2 [00:03<00:00,  1.58s/it]


Percentage: 45 , loss=165.10329641774297 acc=0.84005


100%|██████████| 2/2 [00:04<00:00,  2.07s/it]
100%|██████████| 2/2 [00:03<00:00,  1.86s/it]
100%|██████████| 2/2 [00:03<00:00,  1.75s/it]
100%|██████████| 2/2 [00:03<00:00,  1.82s/it]
100%|██████████| 2/2 [00:04<00:00,  2.02s/it]
100%|██████████| 2/2 [00:03<00:00,  1.76s/it]
100%|██████████| 2/2 [00:03<00:00,  1.67s/it]
100%|██████████| 2/2 [00:04<00:00,  2.15s/it]
100%|██████████| 2/2 [00:03<00:00,  1.80s/it]
100%|██████████| 2/2 [00:03<00:00,  1.75s/it]


Percentage: 50 , loss=157.77378578707575 acc=0.8467166666666668


100%|██████████| 2/2 [00:04<00:00,  2.34s/it]
100%|██████████| 2/2 [00:03<00:00,  1.94s/it]
100%|██████████| 2/2 [00:03<00:00,  1.90s/it]
100%|██████████| 2/2 [00:04<00:00,  2.24s/it]
100%|██████████| 2/2 [00:03<00:00,  1.90s/it]
100%|██████████| 2/2 [00:04<00:00,  2.11s/it]
100%|██████████| 2/2 [00:04<00:00,  2.09s/it]
100%|██████████| 2/2 [00:04<00:00,  2.09s/it]
100%|██████████| 2/2 [00:04<00:00,  2.19s/it]
100%|██████████| 2/2 [00:03<00:00,  2.00s/it]


Percentage: 55 , loss=159.19190081655978 acc=0.8457250000000001


100%|██████████| 2/2 [00:04<00:00,  2.13s/it]
100%|██████████| 2/2 [00:04<00:00,  2.43s/it]
100%|██████████| 2/2 [00:04<00:00,  2.03s/it]
100%|██████████| 2/2 [00:04<00:00,  2.14s/it]
100%|██████████| 2/2 [00:04<00:00,  2.45s/it]
100%|██████████| 2/2 [00:04<00:00,  2.14s/it]
100%|██████████| 2/2 [00:04<00:00,  2.48s/it]
100%|██████████| 2/2 [00:04<00:00,  2.19s/it]
100%|██████████| 2/2 [00:04<00:00,  2.04s/it]
100%|██████████| 2/2 [00:05<00:00,  2.57s/it]


Percentage: 60 , loss=156.68749888017774 acc=0.8488666666666667


100%|██████████| 2/2 [00:04<00:00,  2.27s/it]
100%|██████████| 2/2 [00:05<00:00,  2.55s/it]
100%|██████████| 2/2 [00:04<00:00,  2.41s/it]
100%|██████████| 2/2 [00:04<00:00,  2.30s/it]
100%|██████████| 2/2 [00:05<00:00,  2.62s/it]
100%|██████████| 2/2 [00:04<00:00,  2.29s/it]
100%|██████████| 2/2 [00:05<00:00,  2.54s/it]
100%|██████████| 2/2 [00:04<00:00,  2.29s/it]
100%|██████████| 2/2 [00:04<00:00,  2.31s/it]
100%|██████████| 2/2 [00:05<00:00,  2.75s/it]


Percentage: 65 , loss=154.44005222022534 acc=0.8515499999999999


100%|██████████| 2/2 [00:05<00:00,  2.52s/it]
100%|██████████| 2/2 [00:05<00:00,  2.85s/it]
100%|██████████| 2/2 [00:05<00:00,  2.55s/it]
100%|██████████| 2/2 [00:05<00:00,  2.97s/it]
100%|██████████| 2/2 [00:04<00:00,  2.41s/it]
100%|██████████| 2/2 [00:05<00:00,  2.53s/it]
100%|██████████| 2/2 [00:05<00:00,  2.67s/it]
100%|██████████| 2/2 [00:05<00:00,  2.52s/it]
100%|██████████| 2/2 [00:05<00:00,  2.82s/it]
100%|██████████| 2/2 [00:04<00:00,  2.47s/it]


Percentage: 70 , loss=151.2595176719129 acc=0.8534666666666666


100%|██████████| 2/2 [00:06<00:00,  3.00s/it]
100%|██████████| 2/2 [00:05<00:00,  2.56s/it]
100%|██████████| 2/2 [00:06<00:00,  3.07s/it]
100%|██████████| 2/2 [00:05<00:00,  2.66s/it]
100%|██████████| 2/2 [00:05<00:00,  2.83s/it]
100%|██████████| 2/2 [00:05<00:00,  2.64s/it]
100%|██████████| 2/2 [00:05<00:00,  2.74s/it]
100%|██████████| 2/2 [00:05<00:00,  2.80s/it]
100%|██████████| 2/2 [00:05<00:00,  2.63s/it]
100%|██████████| 2/2 [00:06<00:00,  3.05s/it]


Percentage: 75 , loss=153.90307120382786 acc=0.8513749999999998


100%|██████████| 2/2 [00:05<00:00,  2.75s/it]
100%|██████████| 2/2 [00:06<00:00,  3.25s/it]
100%|██████████| 2/2 [00:05<00:00,  2.83s/it]
100%|██████████| 2/2 [00:06<00:00,  3.29s/it]
100%|██████████| 2/2 [00:05<00:00,  2.80s/it]
100%|██████████| 2/2 [00:06<00:00,  3.22s/it]
100%|██████████| 2/2 [00:05<00:00,  2.75s/it]
100%|██████████| 2/2 [00:06<00:00,  3.32s/it]
100%|██████████| 2/2 [00:05<00:00,  2.81s/it]
100%|██████████| 2/2 [00:06<00:00,  3.18s/it]


Percentage: 80 , loss=149.41098198443652 acc=0.8557083333333333


100%|██████████| 2/2 [00:06<00:00,  3.04s/it]
100%|██████████| 2/2 [00:07<00:00,  3.50s/it]
100%|██████████| 2/2 [00:06<00:00,  3.21s/it]
100%|██████████| 2/2 [00:06<00:00,  3.25s/it]
100%|██████████| 2/2 [00:06<00:00,  3.37s/it]
100%|██████████| 2/2 [00:06<00:00,  3.08s/it]
100%|██████████| 2/2 [00:06<00:00,  3.33s/it]
100%|██████████| 2/2 [00:06<00:00,  3.05s/it]
100%|██████████| 2/2 [00:06<00:00,  3.36s/it]
100%|██████████| 2/2 [00:06<00:00,  3.16s/it]


Percentage: 85 , loss=144.88574964031577 acc=0.8605833333333333


100%|██████████| 2/2 [00:06<00:00,  3.49s/it]
100%|██████████| 2/2 [00:06<00:00,  3.11s/it]
100%|██████████| 2/2 [00:07<00:00,  3.57s/it]
100%|██████████| 2/2 [00:06<00:00,  3.47s/it]
100%|██████████| 2/2 [00:06<00:00,  3.44s/it]
100%|██████████| 2/2 [00:07<00:00,  3.66s/it]
100%|██████████| 2/2 [00:06<00:00,  3.17s/it]
100%|██████████| 2/2 [00:07<00:00,  3.55s/it]
100%|██████████| 2/2 [00:06<00:00,  3.18s/it]
100%|██████████| 2/2 [00:07<00:00,  3.64s/it]


Percentage: 90 , loss=147.3020055092871 acc=0.8577666666666668


100%|██████████| 2/2 [00:07<00:00,  3.53s/it]
100%|██████████| 2/2 [00:07<00:00,  3.55s/it]
100%|██████████| 2/2 [00:07<00:00,  3.80s/it]
100%|██████████| 2/2 [00:06<00:00,  3.38s/it]
100%|██████████| 2/2 [00:07<00:00,  3.67s/it]
100%|██████████| 2/2 [00:06<00:00,  3.27s/it]
100%|██████████| 2/2 [00:07<00:00,  3.85s/it]
100%|██████████| 2/2 [00:07<00:00,  3.67s/it]
100%|██████████| 2/2 [00:06<00:00,  3.30s/it]
100%|██████████| 2/2 [00:07<00:00,  3.78s/it]


Percentage: 95 , loss=146.0119609594345 acc=0.8582583333333333


100%|██████████| 2/2 [00:07<00:00,  3.63s/it]
100%|██████████| 2/2 [00:07<00:00,  3.89s/it]
100%|██████████| 2/2 [00:08<00:00,  4.05s/it]
100%|██████████| 2/2 [00:06<00:00,  3.49s/it]
100%|██████████| 2/2 [00:08<00:00,  4.09s/it]
100%|██████████| 2/2 [00:07<00:00,  3.90s/it]
100%|██████████| 2/2 [00:07<00:00,  3.53s/it]
100%|██████████| 2/2 [00:08<00:00,  4.04s/it]
100%|██████████| 2/2 [00:07<00:00,  3.75s/it]
100%|██████████| 2/2 [00:07<00:00,  3.82s/it]


Percentage: 100 , loss=144.7901761163026 acc=0.8595416666666666


In [None]:
# Single run on a 1% subset, below the smallest percentage used in the sweep.
SUBSET_PERCENT = 1
single_run_rng = np.random.default_rng(0)
torch.manual_seed(0)  # reproducible weight init and shuffling

subsample = int(len(train_dataset) * SUBSET_PERCENT / 100)
mask = np.zeros(len(train_dataset), dtype=bool)
random_indices = single_run_rng.choice(len(train_dataset), subsample, replace=False)
mask[random_indices] = True

subset_loader = get_subset_loader(train_dataset, mask, num_workers=num_workers)
model = SimpleFCN().to(device)
train_model(model, subset_loader)
loss, acc = validate_model(model, val_dataloader)
print(f"loss={loss} acc={acc}")



100%|██████████| 2/2 [00:00<00:00, 23.88it/s]


loss=346.98670893907547 acc=0.6820833333333334
