In [1]:
import torch
import torchvision
from tqdm import tqdm

In [3]:
transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize((0),(1))
])
dataset_train = torchvision.datasets.FashionMNIST("./",train=True,download=True,transform=transform)
dataset_test = torchvision.datasets.FashionMNIST("./",train=False,download=True,transform=transform)
train_data = torch.utils.data.DataLoader(dataset_train,batch_size=32,shuffle=True)
test_data = torch.utils.data.DataLoader(dataset_test,batch_size=32)

In [4]:
class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fully_connect1 = torch.nn.Linear(28*28,512)
        self.fully_connect2 = torch.nn.Linear(512,128)
        self.fully_connect3 = torch.nn.Linear(128,10)
        
    def forward(self,x):
        x = x.reshape((x.shape[0],28*28))
        x = self.fully_connect1(x)
        x = torch.relu(x)
        x = torch.dropout(x,0.2,train=True)
        x = self.fully_connect2(x)
        x = torch.relu(x)
        x = torch.dropout(x,0.5,train=True)
        x = self.fully_connect3(x)
        x = torch.softmax(x,dim=1)
        
        return x

In [5]:
device = torch.device("cpu")
model = Model()
model.to(device)
model.train = True

In [6]:
optimizer = torch.optim.Adam(model.parameters(),lr=0.001)
loss_function = torch.nn.CrossEntropyLoss()

In [7]:
def accuracy(predict,label):
    _,pred_max = torch.max(predict,1)
    acc = torch.sum(pred_max==label,dtype=torch.float64) / len(predict)
    return acc

In [8]:
for epoch in range(20):
    train_loss = 0.0
    train_acc = 0.0
    test_loss = 0.0
    test_acc = 0.0
    for image, label in tqdm(train_data):
        image.to(device)
        label.to(device)
        optimizer.zero_grad()
        predict = model(image)
        loss = loss_function(predict,label)
        loss.backward()
        optimizer.step()
        
        train_loss += loss
        train_acc += accuracy(predict,label)
        
    for image, label in test_data:
        image.to(device)
        label.to(device)
        predict = model(image)
        loss = loss_function(predict,label)
        
        test_loss += loss
        test_acc += accuracy(predict,label)
        
    total_train_loss = train_loss / len(train_data)
    total_train_acc = train_acc / len(train_data)
    total_test_loss = test_loss / len(test_data)
    total_test_acc = test_acc / len(test_data)
    print(f"Epoch: {epoch+1}, Accuracy: {total_train_acc}, Loss: {total_train_loss}")
    print(f"Epoch: {epoch+1}, Val_Accuracy: {total_test_acc}, Val_Loss: {total_test_loss}")

100%|██████████| 1875/1875 [00:20<00:00, 90.30it/s]


Epoch: 1, Accuracy: 0.7338333333333333, Loss: 1.7303436994552612
Epoch: 1, Val_Accuracy: 0.7785543130990416, Val_Loss: 1.6831238269805908


100%|██████████| 1875/1875 [00:22<00:00, 84.58it/s]


Epoch: 2, Accuracy: 0.8141333333333334, Loss: 1.6470433473587036
Epoch: 2, Val_Accuracy: 0.8075079872204473, Val_Loss: 1.6526777744293213


100%|██████████| 1875/1875 [00:22<00:00, 82.96it/s]


Epoch: 3, Accuracy: 0.8270333333333333, Loss: 1.6338852643966675
Epoch: 3, Val_Accuracy: 0.8092052715654952, Val_Loss: 1.6514874696731567


100%|██████████| 1875/1875 [00:23<00:00, 78.35it/s]


Epoch: 4, Accuracy: 0.83255, Loss: 1.628200888633728
Epoch: 4, Val_Accuracy: 0.8172923322683706, Val_Loss: 1.6433879137039185


100%|██████████| 1875/1875 [00:25<00:00, 74.72it/s]


Epoch: 5, Accuracy: 0.8380666666666666, Loss: 1.6227941513061523
Epoch: 5, Val_Accuracy: 0.832667731629393, Val_Loss: 1.6283513307571411


100%|██████████| 1875/1875 [00:25<00:00, 72.40it/s]


Epoch: 6, Accuracy: 0.83895, Loss: 1.6218496561050415
Epoch: 6, Val_Accuracy: 0.8348642172523961, Val_Loss: 1.625993251800537


100%|██████████| 1875/1875 [00:26<00:00, 70.35it/s]


Epoch: 7, Accuracy: 0.8356, Loss: 1.6251881122589111
Epoch: 7, Val_Accuracy: 0.8324680511182109, Val_Loss: 1.6278795003890991


100%|██████████| 1875/1875 [00:26<00:00, 69.95it/s]


Epoch: 8, Accuracy: 0.8396166666666667, Loss: 1.6210665702819824
Epoch: 8, Val_Accuracy: 0.8249800319488818, Val_Loss: 1.635846734046936


100%|██████████| 1875/1875 [00:27<00:00, 69.10it/s]


Epoch: 9, Accuracy: 0.8398166666666667, Loss: 1.620877981185913
Epoch: 9, Val_Accuracy: 0.8147963258785943, Val_Loss: 1.6454881429672241


100%|██████████| 1875/1875 [00:27<00:00, 68.52it/s]


Epoch: 10, Accuracy: 0.84095, Loss: 1.6197147369384766
Epoch: 10, Val_Accuracy: 0.8357627795527156, Val_Loss: 1.6251050233840942


100%|██████████| 1875/1875 [00:27<00:00, 67.05it/s]


Epoch: 11, Accuracy: 0.8398666666666667, Loss: 1.6208328008651733
Epoch: 11, Val_Accuracy: 0.8353634185303515, Val_Loss: 1.625266432762146


100%|██████████| 1875/1875 [00:27<00:00, 68.20it/s]


Epoch: 12, Accuracy: 0.8431, Loss: 1.6177085638046265
Epoch: 12, Val_Accuracy: 0.8110023961661342, Val_Loss: 1.6501234769821167


100%|██████████| 1875/1875 [00:27<00:00, 68.41it/s]


Epoch: 13, Accuracy: 0.8410166666666666, Loss: 1.6198278665542603
Epoch: 13, Val_Accuracy: 0.8281749201277955, Val_Loss: 1.6325684785842896


100%|██████████| 1875/1875 [00:26<00:00, 70.50it/s]


Epoch: 14, Accuracy: 0.8373333333333334, Loss: 1.6236129999160767
Epoch: 14, Val_Accuracy: 0.8371605431309904, Val_Loss: 1.6238802671432495


100%|██████████| 1875/1875 [00:27<00:00, 67.97it/s]


Epoch: 15, Accuracy: 0.84165, Loss: 1.6192939281463623
Epoch: 15, Val_Accuracy: 0.8321685303514377, Val_Loss: 1.6285158395767212


100%|██████████| 1875/1875 [00:27<00:00, 68.69it/s]


Epoch: 16, Accuracy: 0.8416, Loss: 1.619328498840332
Epoch: 16, Val_Accuracy: 0.8274760383386581, Val_Loss: 1.633333683013916


100%|██████████| 1875/1875 [00:27<00:00, 67.43it/s]


Epoch: 17, Accuracy: 0.8433833333333334, Loss: 1.617506980895996
Epoch: 17, Val_Accuracy: 0.8197883386581469, Val_Loss: 1.641162633895874


100%|██████████| 1875/1875 [00:27<00:00, 68.78it/s]


Epoch: 18, Accuracy: 0.8407666666666667, Loss: 1.620300531387329
Epoch: 18, Val_Accuracy: 0.819888178913738, Val_Loss: 1.6413416862487793


100%|██████████| 1875/1875 [00:27<00:00, 67.50it/s]


Epoch: 19, Accuracy: 0.84355, Loss: 1.617353916168213
Epoch: 19, Val_Accuracy: 0.827276357827476, Val_Loss: 1.6338437795639038


100%|██████████| 1875/1875 [00:27<00:00, 66.97it/s]


Epoch: 20, Accuracy: 0.8395, Loss: 1.6215580701828003
Epoch: 20, Val_Accuracy: 0.8258785942492013, Val_Loss: 1.6352193355560303


In [9]:
torch.save(model.state_dict(),"weights.pth")