In [None]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from IPython.display import HTML
from pathlib import Path
import h5py

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchsummary import summary

In [None]:
# Directory that holds the HDF5 files, relative to the current working directory.
# Built from path components instead of the literal '.\hdf5': that literal used a
# Windows-only separator and the invalid escape sequence '\h', so on other
# platforms it named a single file called '.\hdf5' rather than a directory.
file_path = Path('.') / 'hdf5'
# Candidate dataset files; the entry at index 2 is the one used below.
file_names = ['tv_raw.hdf5', 'tv_crop.hdf5', 'tv_process.hdf5']
file_name = file_path / file_names[2]

### Normalization

In [None]:
# Read both arrays completely into memory; they are only needed here to derive
# the normalization statistics.
with h5py.File(file_name, 'r') as hdf:
    images = hdf['tv_images'][:]
    points = hdf['points'][:]

# One mean/std over every element of the image array, and one mean/std per
# label column (reduction over axis 0).
target_norm, target_std = np.mean(images), np.std(images)
label_norm, label_std = np.mean(points, axis=0), np.std(points, axis=0)

# Only the statistics are kept; drop the raw arrays again.
del images, points

In [None]:
# Show the normalization statistics, one per line.
stat_lines = [
    f"target_mean: {target_norm:.4f}",
    f"target_std: {target_std:.4f}",
    f"label_mean: {label_norm}",
    f"label_std: {label_std}",
]
print("\n".join(stat_lines))

### Dataset and Data Loading

In [None]:
class TVDataset(Dataset):
    """In-memory dataset of images and labels read from an HDF5 file.

    The ``'tv_images'`` and ``'points'`` datasets are read completely in the
    constructor. The file is opened with a context manager, so it is closed
    again before the optional transform runs and also when reading fails.
    No file handle is kept on the instance: only the loaded data is stored,
    which keeps the dataset free of h5py objects (those refuse to be pickled,
    which matters once a ``DataLoader`` uses worker processes).

    Args:
        file_path: Path of the HDF5 file to read.
        transform: Optional callable ``transform(target, label)`` that is
            applied once to the full arrays and must return the
            ``(target, label)`` pair to store.
    """

    def __init__(self, file_path, transform=None):
        self.file_path = file_path
        self.transform = transform
        with h5py.File(file_path, 'r') as hdf:
            # Insert a singleton axis at position 1 of the image array.
            target = hdf['tv_images'][:][:, np.newaxis, ...]
            label = hdf['points'][:]
        # Everything is in memory now, so the transform runs on a closed file.
        if self.transform is not None:
            target, label = self.transform(target, label)
        self.target = target
        self.label = label

    def __len__(self):
        """Return the number of samples (size of the first axis of the targets)."""
        return self.target.shape[0]

    def __getitem__(self, idx):
        """Return the ``(target, label)`` pair stored at ``idx``."""
        return self.target[idx], self.label[idx]

class CustomTransform:
    """Standardize targets and labels and return them as float32 tensors.

    Each input has its mean subtracted and is divided by its standard
    deviation. A standard deviation that is exactly zero (a constant input
    or a constant label column) is replaced by one, so the result is zero
    for that entry instead of NaN/inf.

    Args:
        target_norm: Value subtracted from the targets (their mean).
        label_norm: Value(s) subtracted from the labels (their mean).
        target_std: Value the centered targets are divided by.
        label_std: Value(s) the centered labels are divided by.
    """

    def __init__(self, target_norm, label_norm, target_std, label_std):
        self.target_norm = target_norm
        self.label_norm = label_norm
        self.target_std = target_std
        self.label_std = label_std

    @staticmethod
    def _safe_scale(std):
        """Return ``std`` with exact zeros replaced by one; other values are unchanged."""
        if np.ndim(std) == 0:
            # Scalars are passed through untouched unless they are zero.
            return 1.0 if std == 0 else std
        std = np.asarray(std)
        return np.where(std == 0, np.ones_like(std), std)

    def __call__(self, target, label):
        """Return the standardized ``(target, label)`` pair as float32 tensors."""
        target = (target - self.target_norm) / self._safe_scale(self.target_std)
        label = (label - self.label_norm) / self._safe_scale(self.label_std)
        # Convert straight to float32 rather than building a tensor in the
        # array's own dtype first and casting it afterwards.
        return (torch.as_tensor(target, dtype=torch.float32),
                torch.as_tensor(label, dtype=torch.float32))

In [None]:
class ConvNet(nn.Module):
    """Small convolutional regressor: single-channel image in, 4 values out.

    Two ``Conv2d`` layers (5x5 kernels, padding 1), each followed by an 8x8
    max-pool, feed two fully connected layers. The only non-linearity is the
    max-pooling: no activation function is applied, so ``fc1`` followed by
    ``fc2`` composes to a single affine map.

    Args:
        input_size: ``(height, width)`` of the input images. The input width of
            ``fc1`` is derived from it; the default ``(240, 480)`` yields the
            672 features that were previously hard-coded.
    """

    def __init__(self, input_size=(240, 480)):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 16, kernel_size=5, stride=1, padding=1)
        self.pool = nn.MaxPool2d(8, 8)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=5, stride=1, padding=1)
        # Push one dummy image through the convolutional part to find out how
        # many features reach the first fully connected layer.
        with torch.no_grad():
            probe = torch.zeros(1, 1, *input_size)
            n_features = self._extract_features(probe).size(1)
        self.fc1 = nn.Linear(n_features, 128)
        self.fc2 = nn.Linear(128, 4)

    def _extract_features(self, x):
        """Run the conv/pool stages and flatten everything but the batch axis."""
        x = self.pool(self.conv1(x))
        x = self.pool(self.conv2(x))
        return torch.flatten(x, 1)

    def forward(self, x):
        """Map a batch shaped ``(batch, 1, height, width)`` to ``(batch, 4)``."""
        x = self._extract_features(x)
        x = self.fc1(x)
        return self.fc2(x)

In [None]:
def train_model(model, dataloader, optimizer, criterion, n_epochs, device):
    """Train ``model`` for ``n_epochs`` passes over ``dataloader``.

    After every epoch the average loss per sample is printed. Batch losses are
    accumulated as per-batch totals and divided by the number of samples that
    were actually processed. Previously the per-batch *mean* losses were summed
    and divided by the dataset size, which under-reported the loss by roughly
    the batch size.

    Args:
        model: Module to optimize; it is switched to training mode.
        dataloader: Iterable yielding ``(inputs, labels)`` batches.
        optimizer: Optimizer that updates the parameters of ``model``.
        criterion: Loss callable ``criterion(outputs, labels)``. If it has a
            ``reduction`` attribute equal to ``'sum'`` its value is taken as a
            batch total; otherwise it is treated as a batch mean.
        n_epochs: Number of passes over ``dataloader``.
        device: Device the batches are moved to before the forward pass.

    Returns:
        The trained ``model`` (the same object that was passed in).
    """
    model.train()
    # A 'sum' reduction already yields a batch total; anything else is treated
    # as a batch mean and weighted by the batch size below.
    loss_is_batch_total = getattr(criterion, 'reduction', 'mean') == 'sum'
    for epoch in range(n_epochs):
        running_loss = 0.0
        n_seen = 0
        for inputs, labels in dataloader:
            inputs = inputs.to(device)
            labels = labels.to(device)

            # Zero the parameter gradients
            optimizer.zero_grad()

            # Forward pass
            outputs = model(inputs)
            loss = criterion(outputs, labels)

            # Backward pass and optimization
            loss.backward()
            optimizer.step()

            # Accumulate the batch total so the last, smaller batch is weighted correctly
            batch_size = inputs.size(0)
            batch_loss = loss.item()
            running_loss += batch_loss if loss_is_batch_total else batch_loss * batch_size
            n_seen += batch_size
        # max(..., 1) keeps an empty dataloader from raising ZeroDivisionError
        print(f'{epoch + 1} loss: {running_loss / max(n_seen, 1):.3}')

    print('Finished Training')
    return model

In [None]:
# Use the first CUDA device when one is available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Build the network once and move it to the selected device. A second,
# immediately discarded ConvNet() used to be created before this line.
model = ConvNet().to(device)
# Layer-by-layer overview for a single-channel 240x480 input.
summary(model, input_size = (1, 240, 480))

In [None]:
# Standardize with the statistics computed earlier and load the selected file
# into an in-memory dataset.
transform = CustomTransform(
    target_norm=target_norm,
    label_norm=label_norm,
    target_std=target_std,
    label_std=label_std,
)
test_dataset = TVDataset(file_path=file_name, transform=transform)

In [None]:
# Number of samples held by the dataset.
n_loaded = len(test_dataset)
print(n_loaded)

In [None]:
# Shuffled mini-batches of 256 samples, loaded in the main process (no workers).
batch_size = 256
loader_options = dict(batch_size=batch_size, shuffle=True, num_workers=0)
dataloader = DataLoader(test_dataset, **loader_options)

In [None]:
# Look at the first sample: its label and the first channel of its target.
first_target, first_label = test_dataset[0]
first_image = first_target[0]

print(f"dataset length = {len(test_dataset)}")
print(f'label = {first_label}')

# Ratio of the second to the first image axis; scales the colorbar to the image.
im_ratio = first_image.shape[1] / first_image.shape[0]
plt.imshow(first_image, cmap='plasma')
plt.colorbar(orientation="horizontal", fraction=0.047 * im_ratio)
plt.title("target")
plt.show()

In [None]:
# Training setup: mean-squared-error loss, Adam with learning rate 1e-3, 100 epochs.
n_epochs = 100
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)

model = train_model(
    model=model,
    dataloader=dataloader,
    optimizer=optimizer,
    criterion=criterion,
    n_epochs=n_epochs,
    device=device,
)

### Visualize Evaluation

In [None]:
# Evaluation pass over `test_dataset`, the same dataset the model was trained
# on (no separate hold-out set is loaded here). The loader does not shuffle, so
# row i of `predicted` and `actual` belongs to sample i of the dataset; with
# shuffle=True the rows came out in random order and could not be compared
# with labels read directly from the dataset.
new_dataloader = DataLoader(test_dataset, batch_size=batch_size,
                        shuffle=False, num_workers=0)
n_samples = len(new_dataloader.dataset)
output_size = model.fc2.out_features
# Pre-allocated result buffers, filled batch by batch below.
predicted = np.zeros((n_samples, output_size))
actual = np.zeros((n_samples, output_size))
model.eval()

# Iterate over the dataloader and predict the output for each input
with torch.no_grad():
    start_index = 0
    for inputs, actual_outputs in new_dataloader:
        inputs = inputs.to(device)
        predicted_outputs = model(inputs)
        # The last batch may be smaller, so size the slice from the batch itself.
        end_index = start_index + predicted_outputs.shape[0]
        predicted[start_index:end_index] = predicted_outputs.cpu().numpy()
        actual[start_index:end_index] = actual_outputs.numpy()
        start_index = end_index

In [None]:
# Undo the label standardization to get back to the original label units.
# `actual` was filled in the same loop, and therefore in the same row order, as
# `predicted`. Re-reading the labels from the dataset would give dataset order,
# which does not match the predictions when the evaluation loader shuffles.
renorm_actual = actual * label_std + label_norm
renorm_predicted = predicted * label_std + label_norm

print(renorm_actual)
print(renorm_predicted)

In [None]:
# Sum of squared errors per output column, in standardized label units.
# `actual` is row-aligned with `predicted` (both were filled in the same loop);
# labels re-read from the dataset are not when the evaluation loader shuffles.
print(np.sum((actual - predicted)**2, axis=0))

In [None]:
fig, ax = plt.subplots()

# Create scatter plots (empty for now; `update` positions the points per frame)
scat_pred = ax.scatter([], [], c='lime', label='predicted')
scat_actual = ax.scatter([], [], c='red', label='actual')

def update(num):
    """Move both scatter artists to the two points of sample ``num``.

    Each row of ``renorm_predicted`` / ``renorm_actual`` is unpacked as
    ``x1, y1, x2, y2`` and drawn as the points ``(x1, y1)`` and ``(x2, y2)``.
    The changed artists are returned, as blitting requires.
    """
    x1, y1, x2, y2 = renorm_predicted[num]
    a1, b1, a2, b2 = renorm_actual[num]
    scat_pred.set_offsets(np.c_[[x1, x2], [y1, y2]])
    scat_actual.set_offsets(np.c_[[a1, a2], [b1, b2]])
    return scat_pred, scat_actual

ax.legend()
ax.set_xlim([0,201])
ax.set_ylim([0,201])

ani = animation.FuncAnimation(fig, update, frames=range(len(predicted)), interval=30, blit=True, repeat=False)

# Create the output directory first; saving fails if './tmp' does not exist.
animation_path = Path('./tmp/animation.mp4')
animation_path.parent.mkdir(parents=True, exist_ok=True)
ani.save(animation_path, writer='ffmpeg', fps=60)

HTML(ani.to_jshtml())