In [155]:
import tensorflow as tf
import matplotlib.pyplot as plt
import torch
from tqdm import tqdm

In [156]:
# Load Fashion-MNIST through Keras; load_data() yields (train, test) pairs of
# (images, labels).
fashion_mnist = tf.keras.datasets.fashion_mnist.load_data()
(X_train_full, y_train_full), (X_test, y_test) = fashion_mnist

# Hold out the final 10,000 training samples as the validation split.
X_train, X_valid = X_train_full[:-10000], X_train_full[-10000:]
y_train, y_valid = y_train_full[:-10000], y_train_full[-10000:]

# Divide by 255 and reshape every split to (N, 28, 28, 1), i.e. add a
# trailing axis of size one after the two 28-wide axes.
X_train = (X_train / 255.).reshape([-1, 28, 28, 1])
X_valid = (X_valid / 255.).reshape([-1, 28, 28, 1])
X_test = (X_test / 255.).reshape([-1, 28, 28, 1])

In [157]:
# Convert the feature arrays to float32 tensors. Note that the validation
# tensor is bound to the new name X_val; X_valid keeps the original array.
X_train, X_test, X_val = (
    torch.tensor(features, dtype=torch.float32)
    for features in (X_train, X_test, X_valid)
)

# Convert the label arrays to int64 tensors (the names are rebound in place).
y_train, y_valid, y_test = (
    torch.tensor(labels, dtype=torch.int64)
    for labels in (y_train, y_valid, y_test)
)


In [158]:
from torch import nn
from torch.nn import functional as F

In [159]:
# Use CUDA when PyTorch reports it as available, otherwise fall back to CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

print(device)

cpu


In [160]:
# Place every feature tensor on the selected device.
X_train, X_test, X_val = (
    features.to(device) for features in (X_train, X_test, X_val)
)

# Place every label tensor on the selected device.
y_train, y_test, y_valid = (
    labels.to(device) for labels in (y_train, y_test, y_valid)
)

In [161]:
class Model(nn.Module):
    """Fully connected classifier: 784 inputs, three 60-unit ReLU layers, 10 outputs.

    The input is flattened from dimension 1 onward, so each sample must hold
    28*28 = 784 values. The forward pass returns one row of 10 raw scores per
    sample (no softmax is applied inside the module).
    """

    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()

        # Layer widths from input to output; consecutive pairs become Linear layers.
        widths = [28*28, 60, 60, 60, 10]
        stack = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            stack.append(nn.Linear(fan_in, fan_out))
            stack.append(nn.ReLU())
        stack.pop()  # the output layer is not followed by an activation

        self.sequent = nn.Sequential(*stack)

        # Re-initialise the weights of every Linear layer in the stack.
        self.sequent.apply(self.__init_weights)

    def forward(self, x):
        """Flatten ``x`` per sample and run it through the dense stack."""
        return self.sequent(self.flatten(x))

    def __init_weights(self, m):
        """Apply Kaiming-normal init to the weight of ``m`` if it is a Linear layer.

        Biases and non-Linear modules are left untouched.
        """
        if not isinstance(m, nn.Linear):
            return
        torch.nn.init.kaiming_normal_(m.weight)
    

In [162]:
# Seed PyTorch's RNG before construction so the initial weights are reproducible.
torch.manual_seed(42)
model = Model().to(device)
model  # last expression: the notebook displays the module summary

Model(
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (sequent): Sequential(
    (0): Linear(in_features=784, out_features=60, bias=True)
    (1): ReLU()
    (2): Linear(in_features=60, out_features=60, bias=True)
    (3): ReLU()
    (4): Linear(in_features=60, out_features=60, bias=True)
    (5): ReLU()
    (6): Linear(in_features=60, out_features=10, bias=True)
  )
)

In [163]:
# Wrap each (features, labels) pair in a TensorDataset rather than building a
# Python list of per-sample tuples: indexing yields the same (x, y) samples
# without materialising tens of thousands of tuple objects up front.
train_dataset = torch.utils.data.TensorDataset(X_train, y_train)
test_dataset = torch.utils.data.TensorDataset(X_test, y_test)
val_dataset = torch.utils.data.TensorDataset(X_val, y_valid)

# Only the training loader is shuffled; evaluation loaders keep a fixed order
# because shuffling has no effect on aggregate metrics and only makes
# per-batch inspection non-reproducible.
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True)
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=32, shuffle=False)
val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=32, shuffle=False)

In [164]:
# Training hyper-parameters for this cell.
N_EPOCHS = 30             # maximum number of passes over the training set
EARLY_STOP_PATIENCE = 10  # consecutive epochs without a new best val loss before stopping

optimizer = torch.optim.AdamW(model.parameters(), lr=0.01)
# factor=0.5 halves the learning rate when the validation loss plateaus (patience=5).
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=5, factor=0.5)

min_val_loss = float('inf')  # best validation loss seen so far
early_stop = 0               # epochs elapsed since min_val_loss last improved

for i in tqdm(range(N_EPOCHS)):

    # ---- training pass ----
    model.train()
    acc = 0
    total_loss = 0.0
    for X_batch, y_batch in train_dataloader:
        y_pred = model(X_batch)
        loss = F.cross_entropy(y_pred, y_batch)
        # Accumulate a plain Python float: adding the loss tensor itself would
        # chain every batch's autograd history onto the running total.
        total_loss += loss.item()
        acc += (torch.argmax(y_pred, dim=1) == y_batch).sum().item()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    acc /= len(y_train)                  # fraction of training samples predicted correctly
    total_loss /= len(train_dataloader)  # mean of the per-batch mean losses

    # ---- validation pass (no gradients needed) ----
    model.eval()
    val_loss = 0
    val_acc = 0
    with torch.no_grad():
        for X_batch, y_batch in val_dataloader:
            y_pred = model(X_batch)
            val_loss += F.cross_entropy(y_pred, y_batch).item()
            val_acc += (torch.argmax(y_pred, dim=1) == y_batch).sum().item()

    val_loss /= len(val_dataloader)

    print(f"Epoch {i+1} - Accuracy: {acc:.4f} - Loss: {total_loss:.4f} - Val Loss: {val_loss:.4f}, Val accuracy: {val_acc/len(y_valid):.4f}")

    # ---- early stopping on validation loss ----
    if val_loss < min_val_loss:
        min_val_loss = val_loss
        early_stop = 0
    else:
        early_stop += 1
        if early_stop >= EARLY_STOP_PATIENCE:
            break

    # The plateau scheduler decides on LR reduction from the validation loss.
    scheduler.step(val_loss)
        



    

  3%|▎         | 1/30 [00:03<01:54,  3.94s/it]

Epoch 1 - Accuracy: 0.7926 - Loss: 0.5878 - Val Loss: 0.4813, Val accuracy: 0.8303


  7%|▋         | 2/30 [00:07<01:49,  3.90s/it]

Epoch 2 - Accuracy: 0.8329 - Loss: 0.4713 - Val Loss: 0.4447, Val accuracy: 0.8436


 10%|█         | 3/30 [00:11<01:48,  4.02s/it]

Epoch 3 - Accuracy: 0.8404 - Loss: 0.4530 - Val Loss: 0.4672, Val accuracy: 0.8313


 13%|█▎        | 4/30 [00:15<01:44,  4.01s/it]

Epoch 4 - Accuracy: 0.8448 - Loss: 0.4450 - Val Loss: 0.4967, Val accuracy: 0.8349


 17%|█▋        | 5/30 [00:19<01:39,  3.98s/it]

Epoch 5 - Accuracy: 0.8479 - Loss: 0.4359 - Val Loss: 0.5018, Val accuracy: 0.8282


 20%|██        | 6/30 [00:25<01:46,  4.45s/it]

Epoch 6 - Accuracy: 0.8498 - Loss: 0.4212 - Val Loss: 0.4591, Val accuracy: 0.8404


 23%|██▎       | 7/30 [00:29<01:43,  4.49s/it]

Epoch 7 - Accuracy: 0.8486 - Loss: 0.4358 - Val Loss: 0.6158, Val accuracy: 0.8194


 27%|██▋       | 8/30 [00:34<01:38,  4.49s/it]

Epoch 8 - Accuracy: 0.8507 - Loss: 0.4230 - Val Loss: 0.4761, Val accuracy: 0.8290


 30%|███       | 9/30 [00:39<01:38,  4.68s/it]

Epoch 9 - Accuracy: 0.8695 - Loss: 0.3638 - Val Loss: 0.3852, Val accuracy: 0.8637


 33%|███▎      | 10/30 [00:44<01:33,  4.67s/it]

Epoch 10 - Accuracy: 0.8704 - Loss: 0.3584 - Val Loss: 0.3813, Val accuracy: 0.8637


 37%|███▋      | 11/30 [00:48<01:29,  4.72s/it]

Epoch 11 - Accuracy: 0.8734 - Loss: 0.3511 - Val Loss: 0.3925, Val accuracy: 0.8600


 40%|████      | 12/30 [00:54<01:29,  4.99s/it]

Epoch 12 - Accuracy: 0.8731 - Loss: 0.3503 - Val Loss: 0.3756, Val accuracy: 0.8670


 43%|████▎     | 13/30 [00:58<01:21,  4.81s/it]

Epoch 13 - Accuracy: 0.8737 - Loss: 0.3448 - Val Loss: 0.3832, Val accuracy: 0.8618


 47%|████▋     | 14/30 [01:03<01:15,  4.69s/it]

Epoch 14 - Accuracy: 0.8741 - Loss: 0.3456 - Val Loss: 0.4350, Val accuracy: 0.8485


 50%|█████     | 15/30 [01:08<01:12,  4.83s/it]

Epoch 15 - Accuracy: 0.8748 - Loss: 0.3420 - Val Loss: 0.4247, Val accuracy: 0.8529


 53%|█████▎    | 16/30 [01:13<01:07,  4.83s/it]

Epoch 16 - Accuracy: 0.8760 - Loss: 0.3410 - Val Loss: 0.4106, Val accuracy: 0.8602


 57%|█████▋    | 17/30 [01:17<01:00,  4.63s/it]

Epoch 17 - Accuracy: 0.8769 - Loss: 0.3404 - Val Loss: 0.3985, Val accuracy: 0.8557


 60%|██████    | 18/30 [01:21<00:53,  4.46s/it]

Epoch 18 - Accuracy: 0.8773 - Loss: 0.3357 - Val Loss: 0.3736, Val accuracy: 0.8659


 63%|██████▎   | 19/30 [01:25<00:48,  4.38s/it]

Epoch 19 - Accuracy: 0.8774 - Loss: 0.3360 - Val Loss: 0.4178, Val accuracy: 0.8552


 67%|██████▋   | 20/30 [01:29<00:42,  4.30s/it]

Epoch 20 - Accuracy: 0.8775 - Loss: 0.3334 - Val Loss: 0.4120, Val accuracy: 0.8643


 70%|███████   | 21/30 [01:33<00:37,  4.21s/it]

Epoch 21 - Accuracy: 0.8778 - Loss: 0.3346 - Val Loss: 0.3848, Val accuracy: 0.8626


 73%|███████▎  | 22/30 [01:37<00:33,  4.17s/it]

Epoch 22 - Accuracy: 0.8800 - Loss: 0.3306 - Val Loss: 0.3898, Val accuracy: 0.8621


 77%|███████▋  | 23/30 [01:42<00:29,  4.15s/it]

Epoch 23 - Accuracy: 0.8770 - Loss: 0.3327 - Val Loss: 0.4027, Val accuracy: 0.8618


 80%|████████  | 24/30 [01:46<00:25,  4.18s/it]

Epoch 24 - Accuracy: 0.8785 - Loss: 0.3324 - Val Loss: 0.3922, Val accuracy: 0.8616


 83%|████████▎ | 25/30 [01:51<00:21,  4.39s/it]

Epoch 25 - Accuracy: 0.8898 - Loss: 0.2965 - Val Loss: 0.3735, Val accuracy: 0.8670


 87%|████████▋ | 26/30 [01:55<00:17,  4.43s/it]

Epoch 26 - Accuracy: 0.8911 - Loss: 0.2932 - Val Loss: 0.3691, Val accuracy: 0.8684


 90%|█████████ | 27/30 [01:59<00:13,  4.36s/it]

Epoch 27 - Accuracy: 0.8923 - Loss: 0.2905 - Val Loss: 0.3891, Val accuracy: 0.8676


 93%|█████████▎| 28/30 [02:04<00:08,  4.33s/it]

Epoch 28 - Accuracy: 0.8926 - Loss: 0.2880 - Val Loss: 0.3730, Val accuracy: 0.8665


 97%|█████████▋| 29/30 [02:08<00:04,  4.44s/it]

Epoch 29 - Accuracy: 0.8939 - Loss: 0.2867 - Val Loss: 0.3634, Val accuracy: 0.8726


100%|██████████| 30/30 [02:12<00:00,  4.43s/it]

Epoch 30 - Accuracy: 0.8931 - Loss: 0.2855 - Val Loss: 0.3919, Val accuracy: 0.8660



