In [1]:
import torch
from torch.utils.data import TensorDataset, DataLoader
import numpy as np

In [5]:
device=torch.device('cuda:0')

In [3]:
x=torch.randn(1000, 6724)
y=torch.randn(1000, 6724)

x= torch.nn.Sequential(torch.nn.Unflatten(1, (1,82,82)))(x)
print(x.size())

torch.Size([1000, 1, 82, 82])


In [6]:
x=x.to(device)
y=y.to(device)
loader=DataLoader(TensorDataset(x, y), batch_size=200)

In [7]:
import torch.nn as nn
import torch.nn.functional as F

class ConvNet(torch.nn.Module):
    def __init__(self, channel_1, channel_2, kernel_dim):
        super(ConvNet, self).__init__()
        # 1 input image channel, 6 output channels, 5x5 square convolution
        # kernel
        self.conv1 = nn.Conv2d(1, channel_1, kernel_dim)
        self.conv2 = nn.Conv2d(channel_1, channel_2, kernel_dim)
        # an affine operation: y = Wx + b
        self.fc1 = nn.Linear(78*78*3, 1000)  # 78*78 from image dimension
        self.fc2 = nn.Linear(1000, 1000)
        self.fc3 = nn.Linear(1000, 6724)
        
    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = torch.flatten(x, 1) # flatten all dimensions except the batch dimension
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

model = ConvNet(channel_1=3, channel_2=3, kernel_dim=3).to(device)
print(model)

ConvNet(
  (conv1): Conv2d(1, 3, kernel_size=(3, 3), stride=(1, 1))
  (conv2): Conv2d(3, 3, kernel_size=(3, 3), stride=(1, 1))
  (fc1): Linear(in_features=18252, out_features=1000, bias=True)
  (fc2): Linear(in_features=1000, out_features=1000, bias=True)
  (fc3): Linear(in_features=1000, out_features=6724, bias=True)
)


In [8]:
input = torch.randn(1, 1, 82, 82)
input=input.to(device)
out = model(input)
print(out.size())

torch.Size([1, 6724])


In [36]:
input = torch.randn(1, 1, 3, 3)
print(input.size())

torch.Size([1, 1, 3, 3])


In [62]:
input = torch.randn(1, 4)
print(input.size())
input= torch.nn.Sequential(torch.nn.Unflatten(1, ( 1,2,2)))(input)
print(input.size())

torch.Size([1, 4])
torch.Size([1, 1, 2, 2])


In [10]:
optimizer=torch.optim.Adam(model.parameters(), lr=1e-4)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1000, gamma=0.1)

In [11]:
loss_epoch=[]
loss_values = []
loss=1
tol=1e-7
epochs=0

print("Epochs    Loss")

while(loss>tol):
    epochs=epochs+1
    
    for x_batch, y_batch in loader:
        # Forward pass
        y_pred=model(x_batch)
        loss=torch.nn.functional.mse_loss(y_pred, y_batch)
        
        # Backward pass
        loss.backward()
        
        # Update Weights
        optimizer.step()
        optimizer.zero_grad()
    
    loss_epoch.append(epochs)
    loss_values.append(loss.item())
    
    if epochs%1==0:
        print("Epochs: ", epochs, "; Loss: ", loss.item())
        
    loss=loss.item()

print(epochs, "    ", loss.item())

#Plot loss function
from matplotlib import pyplot as plt
plt.plot(loss_epoch, loss_values)
plt.xlabel('epochs')
plt.ylabel('loss')

Epochs    Loss
Epochs:  1 ; Loss:  0.9989970326423645
Epochs:  2 ; Loss:  0.9984937906265259
Epochs:  3 ; Loss:  0.9979026913642883
Epochs:  4 ; Loss:  0.9969338774681091
Epochs:  5 ; Loss:  0.995715856552124
Epochs:  6 ; Loss:  0.9943270087242126
Epochs:  7 ; Loss:  0.9927055835723877
Epochs:  8 ; Loss:  0.9909926652908325
Epochs:  9 ; Loss:  0.9891831874847412
Epochs:  10 ; Loss:  0.9872633218765259
Epochs:  11 ; Loss:  0.9851192235946655
Epochs:  12 ; Loss:  0.9826411604881287
Epochs:  13 ; Loss:  0.9798600077629089
Epochs:  14 ; Loss:  0.9769915342330933
Epochs:  15 ; Loss:  0.973777174949646
Epochs:  16 ; Loss:  0.9705660343170166
Epochs:  17 ; Loss:  0.9670158624649048
Epochs:  18 ; Loss:  0.9633798599243164
Epochs:  19 ; Loss:  0.959646463394165
Epochs:  20 ; Loss:  0.9557541608810425
Epochs:  21 ; Loss:  0.9517802596092224
Epochs:  22 ; Loss:  0.9477421045303345
Epochs:  23 ; Loss:  0.9442567229270935
Epochs:  24 ; Loss:  0.9402888417243958
Epochs:  25 ; Loss:  0.93638670444488

Epochs:  204 ; Loss:  0.6108357310295105
Epochs:  205 ; Loss:  0.6097609996795654
Epochs:  206 ; Loss:  0.6069252490997314
Epochs:  207 ; Loss:  0.6047159433364868
Epochs:  208 ; Loss:  0.6040868163108826
Epochs:  209 ; Loss:  0.6036326885223389
Epochs:  210 ; Loss:  0.6028020977973938
Epochs:  211 ; Loss:  0.6012420654296875
Epochs:  212 ; Loss:  0.5994646549224854
Epochs:  213 ; Loss:  0.5993980765342712
Epochs:  214 ; Loss:  0.6007200479507446
Epochs:  215 ; Loss:  0.6002305150032043
Epochs:  216 ; Loss:  0.5975459218025208
Epochs:  217 ; Loss:  0.5957473516464233
Epochs:  218 ; Loss:  0.5949562191963196
Epochs:  219 ; Loss:  0.5937219262123108
Epochs:  220 ; Loss:  0.592682421207428
Epochs:  221 ; Loss:  0.5916838645935059
Epochs:  222 ; Loss:  0.5905284881591797
Epochs:  223 ; Loss:  0.5895203948020935
Epochs:  224 ; Loss:  0.5884223580360413
Epochs:  225 ; Loss:  0.5873481631278992
Epochs:  226 ; Loss:  0.5864571332931519
Epochs:  227 ; Loss:  0.5855309963226318
Epochs:  228 ; Lo

Epochs:  404 ; Loss:  0.4690741002559662
Epochs:  405 ; Loss:  0.468241810798645
Epochs:  406 ; Loss:  0.4673086404800415
Epochs:  407 ; Loss:  0.4666769504547119
Epochs:  408 ; Loss:  0.4651685059070587
Epochs:  409 ; Loss:  0.46428096294403076
Epochs:  410 ; Loss:  0.4640306532382965
Epochs:  411 ; Loss:  0.46374839544296265
Epochs:  412 ; Loss:  0.4631724953651428
Epochs:  413 ; Loss:  0.46229252219200134
Epochs:  414 ; Loss:  0.46147093176841736
Epochs:  415 ; Loss:  0.4609399139881134
Epochs:  416 ; Loss:  0.46101775765419006
Epochs:  417 ; Loss:  0.46167564392089844
Epochs:  418 ; Loss:  0.46182310581207275
Epochs:  419 ; Loss:  0.46071669459342957
Epochs:  420 ; Loss:  0.4594693183898926
Epochs:  421 ; Loss:  0.4587554633617401
Epochs:  422 ; Loss:  0.4578987956047058
Epochs:  423 ; Loss:  0.4570467472076416
Epochs:  424 ; Loss:  0.4564476013183594
Epochs:  425 ; Loss:  0.456102192401886
Epochs:  426 ; Loss:  0.45625340938568115
Epochs:  427 ; Loss:  0.45643699169158936
Epochs: 

Epochs:  602 ; Loss:  0.3727775514125824
Epochs:  603 ; Loss:  0.37265777587890625
Epochs:  604 ; Loss:  0.37279707193374634
Epochs:  605 ; Loss:  0.37125539779663086
Epochs:  606 ; Loss:  0.3694111108779907
Epochs:  607 ; Loss:  0.3691807687282562
Epochs:  608 ; Loss:  0.36889931559562683
Epochs:  609 ; Loss:  0.36830711364746094
Epochs:  610 ; Loss:  0.36868175864219666
Epochs:  611 ; Loss:  0.36914926767349243
Epochs:  612 ; Loss:  0.36905932426452637
Epochs:  613 ; Loss:  0.3690318167209625
Epochs:  614 ; Loss:  0.36873024702072144
Epochs:  615 ; Loss:  0.3673945367336273
Epochs:  616 ; Loss:  0.3664090633392334
Epochs:  617 ; Loss:  0.3664132058620453
Epochs:  618 ; Loss:  0.3658519983291626
Epochs:  619 ; Loss:  0.3647475838661194
Epochs:  620 ; Loss:  0.3644118010997772
Epochs:  621 ; Loss:  0.36457186937332153
Epochs:  622 ; Loss:  0.36479616165161133
Epochs:  623 ; Loss:  0.36429521441459656
Epochs:  624 ; Loss:  0.3627541661262512
Epochs:  625 ; Loss:  0.3613343834877014
Epoc

KeyboardInterrupt: 