In [1]:
import pickle
import numpy as np
import torch
import torch.nn as nn
from matplotlib import pyplot as plt

In [2]:
lf = 3
pkl_file = open(f'../data/cat{lf}.pkl', 'rb')

data = pickle.load(pkl_file)
pkl_file.close()
data.shape

(794, 90000, 3)

In [3]:
image_chop = 10
data = data.reshape(int(data.shape[0] * image_chop), int(data.shape[1] / image_chop), 3)
data.shape

(7940, 9000, 3)

In [4]:
device = torch.device('cuda:0')

In [5]:
x = torch.tensor(data, dtype=torch.float32, device=device)

In [6]:
x.shape

torch.Size([7940, 9000, 3])

In [7]:
# params
N = x.shape[0]
D_in = x.shape[1]
H1 = int(x.shape[1]*0.8)
H2 = int(H1 * 0.7)
H3 = int(H2 * 0.6)
D_out = int(H3 * 0.5)

learning_rate = 0.001
batch_size = image_chop * 4
epochs = 200

# Regularisierung
weight_decay=0.001

In [8]:
# Neural Network
class autoencoder(nn.Module):
    def __init__(self):
        super(autoencoder, self).__init__()
        
        self.encoder = torch.nn.Sequential(
            nn.Linear(D_in, H1),
            nn.ReLU(),
            nn.Linear(H1, H2),
            nn.ReLU(),
            nn.Linear(H2, H3),
            nn.ReLU(),
            nn.Linear(H3, D_out)
        )

        self.decoder = torch.nn.Sequential(
            nn.Linear(D_out, H3),
            nn.ReLU(),
            nn.Linear(H3, H2),
            nn.ReLU(),
            nn.Linear(H2, H1),
            nn.ReLU(),
            nn.Linear(H1, D_in),
            nn.Tanh()
        )
        
    def forward(self, x):
        x = self.encoder(x)
        x = self.decoder(x)
        return x

In [9]:
model = autoencoder()
model.cuda()

autoencoder(
  (encoder): Sequential(
    (0): Linear(in_features=9000, out_features=7200, bias=True)
    (1): ReLU()
    (2): Linear(in_features=7200, out_features=5040, bias=True)
    (3): ReLU()
    (4): Linear(in_features=5040, out_features=3024, bias=True)
    (5): ReLU()
    (6): Linear(in_features=3024, out_features=1512, bias=True)
  )
  (decoder): Sequential(
    (0): Linear(in_features=1512, out_features=3024, bias=True)
    (1): ReLU()
    (2): Linear(in_features=3024, out_features=5040, bias=True)
    (3): ReLU()
    (4): Linear(in_features=5040, out_features=7200, bias=True)
    (5): ReLU()
    (6): Linear(in_features=7200, out_features=9000, bias=True)
    (7): Tanh()
  )
)

In [10]:
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
criterion = nn.MSELoss()

In [11]:
# Train
loss_hist = []
for t in range(epochs):
    for batch in range(0, int(N/batch_size)):
        # Berechne den Batch
        batch_x = x[batch * batch_size : (batch + 1) * batch_size, :].transpose(1, 2)
        
        # Berechne die Vorhersage (foward step)
        outputs = model.forward(batch_x)
        
        # Berechne den Fehler
        loss = criterion(outputs, batch_x)
        
        # Berechne die Gradienten und Aktualisiere die Gewichte (backward step)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
    # Berechne den Fehler (Ausgabe des Fehlers alle x Iterationen)
    if t % 2 == 0:
        loss_hist.append(loss.item())
        print(t, loss.item())

0 0.06895352154970169
2 0.0663219690322876
4 0.06515029817819595
6 0.06485523283481598
8 0.06455817073583603
10 0.06401797384023666
12 0.06455269455909729
14 0.06534720957279205
16 0.06583148241043091
18 0.06481381505727768
20 0.06479747593402863
22 0.0644020065665245
24 0.0738038495182991
26 0.0650826245546341
28 0.06541869044303894
30 0.06525280326604843
32 0.06466757506132126
34 0.064797542989254
36 0.06481243669986725
38 0.06633490324020386
40 0.06509104371070862
42 0.06513296812772751
44 0.8250378370285034
46 0.5847207903862
48 0.5921646952629089
50 0.541808545589447
52 0.5080893039703369
54 0.4616861343383789
56 0.474700391292572
58 0.44034600257873535
60 0.4342859983444214
62 0.44449126720428467
64 0.07655840367078781
66 0.06462063640356064
68 0.06545939296483994
70 0.06528254598379135
72 0.0646962895989418
74 0.0645107626914978
76 0.06454300880432129
78 0.06515780091285706
80 0.0650234967470169
82 0.064806267619133
84 0.06452348083257675
86 0.06450057774782181
88 0.063804641366

KeyboardInterrupt: 

In [None]:
plt.plot(loss_hist)