In [10]:
import torch

# See if GPU is available
print(torch.cuda.is_available())

# Print Detail of GPU
print(torch.cuda.get_device_properties(0))

True
_CudaDeviceProperties(name='NVIDIA GeForce GTX 1050', major=6, minor=1, total_memory=3014MB, multi_processor_count=6)


In [5]:
import numpy as np

# Set random seed for reproducibility
np.random.seed(42)

# Initialize input
input_image = np.random.rand(8, 8, 1)

# Initialize weights
w1 = np.random.rand(3, 3, 1, 2)  # Encoder conv1
w2 = np.random.rand(3, 3, 2, 4)  # Encoder conv2
w3 = np.random.rand(3, 3, 4, 2)  # Decoder conv1
w4 = np.random.rand(3, 3, 3, 1)  # Decoder conv2

def print_step(step_name, data):
    print(f"\n{step_name}:")
    # print(data)
    print(f"Shape: {data.shape}")

def conv2d(input, weight):
    return np.pad(input, ((1, 1), (1, 1), (0, 0)))[1:-1, 1:-1]

def pool2d(input):
    return input[::2, ::2]

def upsample2d(input):
    return np.repeat(np.repeat(input, 2, axis=0), 2, axis=1)

# Forward pass
print_step("Input", input_image)

# Encoder
conv1 = conv2d(input_image, w1)
print_step("Encoder Conv1", conv1)

pool1 = pool2d(conv1)
print_step("Encoder Pool1", pool1)

conv2 = conv2d(pool1, w2)
print_step("Encoder Conv2", conv2)

pool2 = pool2d(conv2)
print_step("Encoder Pool2 (Bottleneck)", pool2)

# Decoder
up1 = upsample2d(pool2)
print_step("Decoder Upsample1", up1)

concat1 = np.concatenate([up1, conv2], axis=2)
print_step("Decoder Concat1", concat1)

conv3 = conv2d(concat1, w3)
print_step("Decoder Conv1", conv3)

up2 = upsample2d(conv3)
print_step("Decoder Upsample2", up2)

concat2 = np.concatenate([up2, conv1], axis=2)
print_step("Decoder Concat2", concat2)

conv4 = conv2d(concat2, w4)
print_step("Decoder Conv2 (Output)", conv4)


Input:
Shape: (8, 8, 1)

Encoder Conv1:
Shape: (8, 8, 1)

Encoder Pool1:
Shape: (4, 4, 1)

Encoder Conv2:
Shape: (4, 4, 1)

Encoder Pool2 (Bottleneck):
Shape: (2, 2, 1)

Decoder Upsample1:
Shape: (4, 4, 1)

Decoder Concat1:
Shape: (4, 4, 2)

Decoder Conv1:
Shape: (4, 4, 2)

Decoder Upsample2:
Shape: (8, 8, 2)

Decoder Concat2:
Shape: (8, 8, 3)

Decoder Conv2 (Output):
Shape: (8, 8, 3)
