<a href="https://colab.research.google.com/github/antonKornilov1/some/blob/main/Klass_ruk_chisel_mnist.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import random
import numpy as np
import time

# Fix every random number generator used in this notebook (PyTorch CPU and
# CUDA, NumPy, Python's `random`) so shuffling and weight init are repeatable.
torch.manual_seed(0)
torch.cuda.manual_seed(0)
np.random.seed(0)
random.seed(0)
# Request deterministic cuDNN kernels (trades some GPU speed for repeatability).
torch.backends.cudnn.deterministic = True

In [None]:
import torchvision.datasets

# Download (if not already present in './') and load the train split first,
# then the test split.
MNIST_train, MNIST_test = (
    torchvision.datasets.MNIST('./', download=True, train=is_train)
    for is_train in (True, False)
)

In [None]:
# `train_data` / `train_labels` / `test_data` / `test_labels` are deprecated
# aliases in torchvision (they emit warnings); `data` and `targets` are the
# current attribute names and work for both the train and the test split.
X_train = MNIST_train.data
y_train = MNIST_train.targets
X_test = MNIST_test.data
y_test = MNIST_test.targets

In [None]:
# Display the element types of the training images and labels.
tuple(tensor.dtype for tensor in (X_train, y_train))

In [None]:
# Cast pixel tensors to float32 so they can be fed to the Linear layers.
X_train, X_test = X_train.float(), X_test.float()

In [None]:
# Display the shapes of the train and test image tensors.
tuple(images.shape for images in (X_train, X_test))

In [None]:
# Display the shapes of the train and test label tensors.
tuple(labels.shape for labels in (y_train, y_test))

In [None]:
import matplotlib.pyplot as plt

# Sanity check: render the first training image, then print its label.
plt.imshow(X_train[0])
plt.show()
print(y_train[0])

In [None]:
# Flatten each 28x28 image into one 784-element row for the fully connected net.
X_train, X_test = (images.reshape(-1, 28 * 28) for images in (X_train, X_test))

In [None]:
class MNISTNet(torch.nn.Module):
    """Fully connected classifier for flattened 28x28 images.

    Three hidden Linear layers of ``n_hidden_neurons`` units, each followed
    by a ReLU, then a final Linear layer producing 10 raw scores (logits);
    no softmax is applied inside the model.
    """

    def __init__(self, n_hidden_neurons):
        super().__init__()

        # Creation order fixes both the parameter order and the sequence in
        # which the seeded RNG is consumed for weight initialisation.
        self.fc1 = torch.nn.Linear(28 * 28, n_hidden_neurons)
        self.ac1 = torch.nn.ReLU()
        self.fc2 = torch.nn.Linear(n_hidden_neurons, n_hidden_neurons)
        self.ac2 = torch.nn.ReLU()
        self.fc3 = torch.nn.Linear(n_hidden_neurons, n_hidden_neurons)
        self.ac3 = torch.nn.ReLU()
        self.fc4 = torch.nn.Linear(n_hidden_neurons, 10)

    def forward(self, x):
        """Map rows of 784 features in ``x`` to 10 logits per row."""
        hidden = self.ac1(self.fc1(x))
        hidden = self.ac2(self.fc2(hidden))
        hidden = self.ac3(self.fc3(hidden))
        return self.fc4(hidden)

mnist_net = MNISTNet(100)


In [None]:
# Display whether PyTorch can see a CUDA device; the device selection cell
# below falls back to the CPU when it cannot.
torch.cuda.is_available()

In [None]:
# IPython shell escape: print nvidia-smi's GPU / driver report (only
# meaningful on a GPU runtime).
!nvidia-smi

In [None]:
# Prefer the first CUDA device; otherwise run everything on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

# nn.Module.to() moves the parameters in place and returns the same module.
mnist_net.to(device)
list(mnist_net.parameters())

In [None]:
# Adam over every network parameter, learning rate 1e-3.
optimizer = torch.optim.Adam(params=mnist_net.parameters(), lr=1e-3)
# Cross-entropy on the raw logits; called as loss(preds, targets) below.
loss = torch.nn.CrossEntropyLoss()

In [None]:
batch_size = 100
# NOTE(review): 10000 epochs is far more than this MNIST MLP needs; reduce
# for interactive runs.
n_epochs = 10000

# Per-epoch histories, stored as plain Python floats so they can be plotted
# directly (no GPU tensors kept alive).
test_accuracy_history = []
test_loss_history = []
train_loss_history = []
epoch_times = []

# The whole test set is evaluated at once, so keep it on the model's device.
X_test = X_test.to(device)
y_test = y_test.to(device)

for epoch in range(n_epochs):
    start_time = time.time()
    order = np.random.permutation(len(X_train))
    epoch_train_loss = 0.0
    batch_count = 0

    # No-op for the current MNISTNet (only Linear/ReLU layers), but keeps the
    # loop correct if dropout or batch-norm layers are added later.
    mnist_net.train()
    for start_index in range(0, len(X_train), batch_size):
        optimizer.zero_grad()

        batch_indexes = order[start_index:start_index + batch_size]

        # Training data stays on the CPU; move only the current batch to the
        # device the model lives on.
        X_batch = X_train[batch_indexes].to(device)
        y_batch = y_train[batch_indexes].to(device)

        preds = mnist_net.forward(X_batch)

        loss_value = loss(preds, y_batch)
        loss_value.backward()

        optimizer.step()

        epoch_train_loss += loss_value.item()
        batch_count += 1

    # Timing covers the training pass only, not the evaluation below.
    epoch_time = time.time() - start_time
    epoch_times.append(epoch_time)

    avg_train_loss = epoch_train_loss / batch_count
    train_loss_history.append(avg_train_loss)

    # Evaluate without building an autograd graph over the full test set.
    mnist_net.eval()
    with torch.no_grad():
        test_preds = mnist_net.forward(X_test)
        val_loss_value = loss(test_preds, y_test).item()
        accuracy = (test_preds.argmax(dim=1) == y_test).float().mean().item()

    test_loss_history.append(val_loss_value)
    test_accuracy_history.append(accuracy)

    print(f"Epoch {epoch}: "
          f"Train Loss = {avg_train_loss:.4f}, "
          f"Val Loss = {val_loss_value:.4f}, "
          f"Accuracy = {accuracy:.4f}, "
          f"Time = {epoch_time:.2f}s")

In [None]:
# Entries may be 0-dim tensors living on the GPU, which matplotlib cannot
# convert; float() copies each value to a plain Python number first.
plt.plot([float(acc) for acc in test_accuracy_history])
plt.xlabel('Epoch')
plt.ylabel('Test accuracy');

In [None]:
# Entries may be 0-dim tensors; convert each to a Python float before plotting.
plt.plot([float(val) for val in test_loss_history])
plt.xlabel('Epoch')
plt.ylabel('Validation loss');

In [None]:
# Histories may hold 0-dim tensors; normalise them to plain floats.
test_accuracy_history = list(map(float, test_accuracy_history))
test_loss_history = list(map(float, test_loss_history))

# Overlay the training and validation loss curves on one figure.
plt.figure(figsize=(10, 5))
plt.plot(train_loss_history, label='Train Loss', alpha=0.7)
plt.plot(test_loss_history, label='Validation Loss', alpha=0.7)
plt.title('Train and Validation Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.grid(True)
plt.legend()
plt.show()