In [1]:
import torch
import numpy as np
import cv2
import glob
import os

device = torch.device('cuda')

In [2]:
device

device(type='cuda')

In [73]:
class Dataset(torch.utils.data.Dataset):
    def __init__(self, path):
        super().__init__()
        self.image_paths = glob.glob(path)
        self.n_samples = len(self.image_paths)
    
    def __getitem__(self, sample_n):
        img_path = self.image_paths[sample_n]
        img = cv2.imread(img_path)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
        img = np.expand_dims(img, 2)
        img = img/255

        label = int(os.path.basename(img_path).split("_")[0])
        return img, label
        
        
    def __len__(self):
        return self.n_samples

In [74]:
train_ds = Dataset("../datasets/mnist/train/*")
# train_dl = torch.utils.data.dataloader(train_ds, batch_size=32)
train_dl = torch.utils.data.DataLoader(
    train_ds,
    batch_size=128,
    shuffle=True,
    num_workers=4
)

In [75]:
val_ds = Dataset("../datasets/mnist/val/*")
# train_dl = torch.utils.data.dataloader(train_ds, batch_size=32)
val_dl = torch.utils.data.DataLoader(
    val_ds,
    batch_size=128,
    shuffle=True,
    num_workers=4
)

In [115]:
class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = torch.nn.Sequential(
        
            torch.nn.Conv2d(1,32, kernel_size=3, stride=2),
            torch.nn.ReLU(),
            torch.nn.Conv2d(32,64, kernel_size=3, stride=2),
            torch.nn.ReLU(),
            torch.nn.Flatten(),
            torch.nn.Linear(2304,10)
        )
        self.decoder0 = torch.nn.Sequential(
            torch.nn.Linear(10,7*7*32),
            torch.nn.ReLU(),)
        self.decoder1 = torch.nn.Sequential(
            torch.nn.Upsample(scale_factor=2),
            torch.nn.ConvTranspose2d(32, 64, kernel_size=3, padding=1),
            torch.nn.ReLU(),
            torch.nn.Upsample(scale_factor=2),
            torch.nn.ConvTranspose2d(64, 32, kernel_size=3, padding=1),
            torch.nn.ReLU(),
            torch.nn.ConvTranspose2d(32, 1, kernel_size=3, padding=1)
        )

        
    def forward(self, x):
        # print(x.shape)
        out = self.encoder(x)
        # print(out.shape)
        out = self.decoder0(out)
        out = out.reshape(-1, 32, 7 ,7)
        out = self.decoder1(out)
        return out

In [116]:
model = Model().to(device)
optimizer = torch.optim.Adam(model.parameters())
# lossf = torch.nn.CrossEntropyLoss()
lossf= torch.nn.BCEWithLogitsLoss()

In [119]:
epochs = 1

for epoch in range(epochs):
    for img, label in train_dl:
        
        optimizer.zero_grad()     
        
        img = img.float().to(device).permute(0,3,1,2)
        # print()
        # label = label.to(device)
        
        reconstruction = model(img)
        # print(img.shape)
        # print(reconstruction.shape)
        
        loss = lossf(reconstruction, img)
        loss.backward()
        
        optimizer.step()
    accs = []
    for batch in val_dl:

        img, label = batch
        img = img.float().to(device).permute(0,3,1,2)
        label = label.to(device)
        with torch.no_grad():
            reconstruction = model(img)

        predict = reconstruction.argmax(1)
        acc = (predict == reconstruction).float().mean().detach().cpu().numpy()
        accs.append(acc)
    print(np.mean(accs))

0.0
