In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Subset
from torchvision import datasets, transforms
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score
import numpy as np


# Block definition (with optional upsampling)
class Block(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, upsample=False):
        super(Block, self).__init__()
        self.upsample = upsample
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=1,  # Ensure output size matches input size
            bias=False
        )
        self.norm = nn.BatchNorm2d(out_channels)
        self.act = nn.ReLU()

    def forward(self, x):
        if self.upsample:
            x = nn.functional.interpolate(x, scale_factor=2, mode="bilinear", align_corners=False)
        x = self.conv(x)
        x = self.norm(x)
        x = self.act(x)
        return x


# AutoEncoder
class AutoEncoder(nn.Module):
    def __init__(self, latent_dim=128):
        super(AutoEncoder, self).__init__()

        # Encoder
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, stride=2, padding=1),  # 28x28 -> 14x14
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),  # 14x14 -> 7x7
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),  # 7x7 -> 4x4
            nn.ReLU(),
        )

        # Bottleneck (latent space)
        self.bottleneck = nn.Sequential(
            nn.Flatten(),
            nn.Linear(128 * 4 * 4, latent_dim),  # Embedding into a 128-dimensional space
        )

     
        self.decoder = nn.Sequential(
            Block(128, 64, upsample=True),  # 4x4 -> 8x8
            Block(64, 32, upsample=True),   # 8x8 -> 16x16
            Block(32, 16, upsample=True),   # 16x16 -> 32x32
            nn.Conv2d(16, 1, kernel_size=3, padding=1),  # 32x32 -> 28x28 (restore original size)
            nn.Sigmoid(),
        )

    def forward(self, x):
        x = self.encoder(x)
        x = self.bottleneck(x)
        x = x.unsqueeze(-1).unsqueeze(-1)  # Reshaping for decoder
        x = self.decoder(x)

        # Resize to ensure the output is exactly 28x28
        x = nn.functional.interpolate(x, size=(28, 28), mode="bilinear", align_corners=False)
        return x

    @torch.no_grad()
    def encode(self, x):
        x = self.encoder(x)  # Получаем закодированное представление (размерность [batch_size, 128, 4, 4])
        x = x.flatten(start_dim=1)  # Преобразуем в 1D вектор с размерностью [batch_size, 2048]
        x = self.bottleneck(x)  # Проходим через слой bottleneck, чтобы получить эмбеддинг размером [batch_size, 128]
        return x  # Возвращаем эмбеддинг размерности [batch_size, 128]



# Save embeddings function
def save_embeddings(x_train, y_train, x_valid, y_valid):
    assert x_train.shape[0] == 1000
    assert x_valid.shape[0] == 10000
    assert y_train.shape[0] == 1000
    assert y_valid.shape[0] == 10000

    torch.save(
        {
            'x_train': x_train,
            'y_train': y_train,
            'x_valid': x_valid,
            'y_valid': y_valid
        },
        'embeddings2.pt'
    )


# Main function
def main():
    # Dataset and Dataloader
    transform = transforms.Compose([transforms.ToTensor()])
    train_data = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
    test_data = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

    train_subset = Subset(train_data, range(1000))
    train_loader = DataLoader(train_subset, batch_size=64, shuffle=True)
    test_loader = DataLoader(test_data, batch_size=64, shuffle=False)

    # Model setup
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    autoencoder = AutoEncoder(latent_dim=128).to(device)
    optimizer = optim.Adam(autoencoder.parameters(), lr=1e-3)
    criterion = nn.MSELoss()

    for epoch in range(45):
        # Train AutoEncoder
        autoencoder.train()
        epoch_loss = 0
        for images, _ in train_loader:
            images = images.to(device)
            optimizer.zero_grad()
            outputs = autoencoder(images)
            loss = criterion(outputs, images)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
        print(f"Epoch {epoch + 1}, Loss: {epoch_loss / len(train_loader):.4f}")

        # Generate embeddings for training and testing data
        autoencoder.eval()
        train_embeddings, train_labels = [], []
        test_embeddings, test_labels = [], []

        with torch.no_grad():
            for images, labels in train_loader:
                images = images.to(device)
                embeddings = autoencoder.encode(images).cpu().numpy()  # Get the feature maps
                train_embeddings.append(embeddings)  # No need to reshape, as embeddings are already flat
                train_labels.append(labels.numpy())

            for images, labels in test_loader:
                images = images.to(device)
                embeddings = autoencoder.encode(images).cpu().numpy()  # Get the feature maps
                test_embeddings.append(embeddings)  # No need to reshape, as embeddings are already flat
                test_labels.append(labels.numpy())

        train_embeddings = np.vstack(train_embeddings)
        train_labels = np.concatenate(train_labels)
        test_embeddings = np.vstack(test_embeddings)
        test_labels = np.concatenate(test_labels)

        # Train Random Forest
        clf = RandomForestClassifier(random_state=0)  # Hyperparameters for improvement
        clf.fit(train_embeddings, train_labels)

        # Evaluate
        predictions = clf.predict(test_embeddings)
        accuracy = accuracy_score(test_labels, predictions)
        print(f"Epoch {epoch + 1}, Test Accuracy: {accuracy * 100:.2f}%")

        if accuracy > 0.90:
            print("Accuracy exceeded 90%, stopping training.")
            break


    # Save embeddings
    save_embeddings(
        x_train=torch.tensor(train_embeddings, dtype=torch.float16),
        y_train=torch.tensor(train_labels, dtype=torch.int32),
        x_valid=torch.tensor(test_embeddings, dtype=torch.float16),
        y_valid=torch.tensor(test_labels, dtype=torch.int32),
    )


if __name__ == "__main__":
    main()


Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:00<00:00, 16099634.12it/s]


Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 471323.10it/s]


Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:00<00:00, 4480637.84it/s]


Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 7240793.91it/s]


Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw

Epoch 1, Loss: 0.1455
Epoch 1, Test Accuracy: 73.31%
Epoch 2, Loss: 0.0839
Epoch 2, Test Accuracy: 73.26%
Epoch 3, Loss: 0.0701
Epoch 3, Test Accuracy: 74.19%
Epoch 4, Loss: 0.0642
Epoch 4, Test Accuracy: 77.73%
Epoch 5, Loss: 0.0605
Epoch 5, Test Accuracy: 80.59%
Epoch 6, Loss: 0.0574
Epoch 6, Test Accuracy: 82.05%
Epoch 7, Loss: 0.0548
Epoch 7, Test Accuracy: 83.51%
Epoch 8, Loss: 0.0527
Epoch 8, Test Accuracy: 84.45%
Epoch 9, Loss: 0.0509
Epoch 9, Test Accuracy: 85.25%
Epoch 10, Loss: 0.0490
Epoch 10, Test Accuracy: 86.15%
Epoch 11, Loss: 0.0472
Epoch 11, Test Accuracy: 86.62%
Epoch 12, Loss: 0.0462
Epoch 12, Test Accuracy: 86.75%
Epoch 13, Loss: 0.0453
Epoch 13, Test Accuracy: 87.52%
Epoch 14, Loss: 0.0443
Epoch 14, Test Accuracy: 87.76%
Epoch 15, Loss: 0.0434
Epoch 15, Test Accuracy: 87.77%
Epoch 16, Loss: 0.0426
Epoch 16, Test Accuracy: 87.67%
Epoch 17, Loss: 0.0418
Epoch 17, Test Accuracy: 88.44%
Epoch 18