In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, ConcatDataset
from torchvision import datasets, transforms

import matplotlib.pyplot as plt

In [None]:
def augmentation_transform():
    """Build the random augmentation + tensor conversion pipeline.

    The steps run in the order listed below on each image passed to the
    returned transform.

    Returns:
        transforms.Compose: the composed pipeline.
    """
    steps = [
        # Random rotation configured with a 10-degree limit.
        transforms.RandomRotation(10),
        # Translation only (no extra rotation): max shift fraction 0.1 per axis.
        transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)),
        # Random crop resized to 28 (output size).
        # NOTE(review): the upper scale bound 1.1 is above 1.0, i.e. it asks
        # for crops larger than the source image - confirm this is intended.
        transforms.RandomResizedCrop(28, scale=(0.9, 1.1)),
        transforms.ToTensor(),
        # Single-channel normalisation with mean 0.5 and std 0.5.
        transforms.Normalize((0.5,), (0.5,)),
    ]
    return transforms.Compose(steps)

In [None]:
# MNIST training split with the augmentation pipeline attached as its
# per-sample transform.
original_dataset = datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=augmentation_transform(),
)

# Extra samples required on top of the base split to reach 100000 in total.
num_augmented_samples = 100000 - len(original_dataset)

# Count how many whole copies of the base split cover that shortfall.
# Only whole copies are added, so the combined size can exceed the target.
base_size = len(original_dataset)
extra_copies = 0
while extra_copies * base_size < num_augmented_samples:
    extra_copies += 1

# Every entry refers to the same dataset object (no data is duplicated).
augmented_datasets = [original_dataset] * extra_copies

full_dataset = ConcatDataset([original_dataset, *augmented_datasets])
print(f"Total samples in the augmented dataset: {len(full_dataset)}")

In [None]:
def plot_images(dataset, num_images=5):
    """Show the first ``num_images`` samples of ``dataset`` in a single row.

    Args:
        dataset: Indexable collection whose items are ``(image, label)``
            pairs; each image must provide ``.squeeze()`` and be drawable
            by ``imshow``.
        num_images: Number of leading samples to draw (default 5).

    Raises:
        ValueError: If ``num_images`` is smaller than 1.
    """
    if num_images < 1:
        raise ValueError(f"num_images must be at least 1, got {num_images}")

    # squeeze=False always yields a 2-D array of axes. Without it a single
    # subplot is returned as a bare Axes object and indexing it fails.
    # The width grows with the number of panels (5 panels -> 10 inches).
    _, axes = plt.subplots(
        1, num_images, figsize=(2 * num_images, 2), squeeze=False
    )
    for i, ax in enumerate(axes[0]):
        img, label = dataset[i]
        # Drop singleton dimensions so imshow receives a 2-D image.
        ax.imshow(img.squeeze(), cmap='gray')
        ax.set_title(f'Label: {label}')
        ax.axis('off')
    plt.show()

# Visualize the augmented images
plot_images(dataset=full_dataset)

In [None]:
batch_size = 64

# Shuffled loader over the combined dataset.
train_loader = DataLoader(
    dataset=full_dataset,
    batch_size=batch_size,
    shuffle=True,
)

# Sanity check: the dataset must split into full batches with no remainder.
assert not (len(full_dataset) % batch_size), "Dataset size is not divisible by batch_size"