In [3]:
import torch
from torch import nn
from torch.utils.data import TensorDataset, Dataset, DataLoader
from torch.optim import Adam
device = 'cuda' if torch.cuda.is_available else 'cpu'
from torchvision import datasets
import numpy as np
import matplotlib.pyplot as plt

In [4]:
X_train = torch.tensor([[[[1,2,3,4],[2,3,4,5],[5,6,7,8],[1,3,4,5]]],
                        [[[-1,2,3,-4],[2,-3,4,5],[-5,6,-7,8],[-1,-3,-4,-5]]
                        ]]).to(device).float()
X_train /= 8

y_train = torch.tensor([0,1]).to(device).float()

In [5]:
X_train.size()

torch.Size([2, 1, 4, 4])

In [6]:
def get_model():
    model = nn.Sequential(
        nn.Conv2d(1,1,kernel_size=3),
        nn.MaxPool2d(2),
        nn.ReLU(),
        nn.Flatten(),
        nn.Linear(1,1),
        nn.Sigmoid(),
    ).to(device)

    # Binary Cross Entropy
    criterion = nn.BCELoss()
    optimizer = Adam(model.parameters(), lr = 0.0001)
    return model, criterion, optimizer

In [7]:
from torchsummary import summary
model,criterion,optimizer = get_model()
summary(model,X_train)

Layer (type:depth-idx)                   Output Shape              Param #
├─Conv2d: 1-1                            [-1, 1, 2, 2]             10
├─MaxPool2d: 1-2                         [-1, 1, 1, 1]             --
├─ReLU: 1-3                              [-1, 1, 1, 1]             --
├─Flatten: 1-4                           [-1, 1]                   --
├─Linear: 1-5                            [-1, 1]                   2
├─Sigmoid: 1-6                           [-1, 1]                   --
Total params: 12
Trainable params: 12
Non-trainable params: 0
Total mult-adds (M): 0.00
Input size (MB): 0.00
Forward/backward pass size (MB): 0.00
Params size (MB): 0.00
Estimated Total Size (MB): 0.00


Layer (type:depth-idx)                   Output Shape              Param #
├─Conv2d: 1-1                            [-1, 1, 2, 2]             10
├─MaxPool2d: 1-2                         [-1, 1, 1, 1]             --
├─ReLU: 1-3                              [-1, 1, 1, 1]             --
├─Flatten: 1-4                           [-1, 1]                   --
├─Linear: 1-5                            [-1, 1]                   2
├─Sigmoid: 1-6                           [-1, 1]                   --
Total params: 12
Trainable params: 12
Non-trainable params: 0
Total mult-adds (M): 0.00
Input size (MB): 0.00
Forward/backward pass size (MB): 0.00
Params size (MB): 0.00
Estimated Total Size (MB): 0.00

In [14]:
train_loader = torch.utils.data.DataLoader(TensorDataset(X_train,y_train))

for x,y in train_loader:
    print(x.size())
    print(y.size())
    break

torch.Size([1, 1, 4, 4])
torch.Size([1])


In [15]:
def train(model,dataloader, opt, criterion):
    model.train()
    for image, label in dataloader:
        
        image = image.to(device)
        label = label.to(device)

        opt.zero_grad()
        prediction = model(image)
        batch_loss = criterion(prediction.squeeze(0),label)

        batch_loss.backward()
        opt.step()

    return batch_loss

In [16]:
for epoch in range(2000):
    train(model,train_loader,optimizer,criterion)

    if epoch % 500 == 0:
        print("Epoch : {}\nLoss : {}".format(epoch,train(model,train_loader,optimizer,criterion)))

Epoch : 0
Loss : 0.7582429647445679
Epoch : 500
Loss : 0.7549789547920227
Epoch : 1000
Loss : 0.7518818378448486
Epoch : 1500
Loss : 0.7489427924156189


In [17]:
model(X_train[:1])

tensor([[0.4742]], device='cuda:0', grad_fn=<SigmoidBackward>)

In [19]:
list(model.children())

[Conv2d(1, 1, kernel_size=(3, 3), stride=(1, 1)),
 MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False),
 ReLU(),
 Flatten(start_dim=1, end_dim=-1),
 Linear(in_features=1, out_features=1, bias=True),
 Sigmoid()]

In [26]:
(cnn_w, cnn_b),(lin_w,lin_b) = [(layer.weight.data,layer.bias.data) for layer in list(model.children()) if hasattr(layer,'weight')]
print(cnn_w,cnn_b)

tensor([[[[ 0.3265, -0.0143, -0.0972],
          [-0.1796, -0.2313,  0.1110],
          [-0.1781, -0.0529,  0.0165]]]], device='cuda:0') tensor([-0.0755], device='cuda:0')


In [27]:
print(lin_w, lin_b)

tensor([[-0.5482]], device='cuda:0') tensor([-0.1034], device='cuda:0')
