In [None]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from IPython.display import HTML
from pathlib import Path
from scipy.io import readsav
from mpl_toolkits.axes_grid1 import make_axes_locatable
from tqdm.notebook import tqdm
import pickle
import h5py

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.cuda.amp import GradScaler, autocast
from torchvision.transforms import ToTensor
from torchvision import transforms

In [None]:
def _load_data(filename):
    """Read an IDL ``.sav`` file and return element 0 of its ``emission_structure`` entry."""
    return readsav(filename)['emission_structure'][0]

In [None]:
# Candidate HDF5 files; the first entry is the one the rest of the notebook reads.
file_names = [
    'tv_raw.hdf5',
    'tv_crop.hdf5',
    'tv_process.hdf5',
]
active_index = 0
file_name = Path(file_names[active_index])

### Neural Network Implementation

In [None]:
class TVDataset(Dataset):
    """Dataset over an HDF5 file holding ``'tv_images'`` and ``'points'``.

    Item ``idx`` is the pair ``(tv_images[idx], points[idx])`` returned as
    NumPy arrays.  The file is opened read-only in the constructor.  If the
    dataset is used again after :meth:`close`, the file is reopened on demand,
    so a clean-up cell that was executed early does not break later cells.

    The dataset also works as a context manager::

        with TVDataset(path) as ds:
            image, points = ds[0]
    """

    def __init__(self, file_path):
        """Remember ``file_path`` and open it.

        Raises whatever ``h5py.File`` raises for an unreadable path, or
        ``KeyError`` when one of the two expected datasets is missing.
        """
        self.file_path = file_path
        self.file = None
        self.target = None
        self.label = None
        self._open()

    def _open(self):
        """Open the HDF5 file read-only and bind the image and label datasets."""
        handle = h5py.File(self.file_path, 'r')
        try:
            target = handle['tv_images']
            label = handle['points']
        except Exception:
            # Do not leak the handle when the file lacks an expected dataset.
            handle.close()
            raise
        self.file = handle
        self.target = target
        self.label = label

    def _ensure_open(self):
        """Reopen the file if :meth:`close` was called since the last access."""
        if self.file is None:
            self._open()

    def __len__(self):
        """Number of samples, i.e. the size of the first axis of ``tv_images``."""
        self._ensure_open()
        return self.target.shape[0]

    def __getitem__(self, idx):
        """Return ``(target, label)`` for ``idx`` as NumPy arrays.

        ``idx`` is forwarded unchanged to both datasets, so anything they
        accept (an integer, a slice, ...) is accepted here.
        """
        self._ensure_open()
        target = np.asarray(self.target[idx])
        label = np.asarray(self.label[idx])
        return target, label

    def close(self):
        """Close the HDF5 file.  Calling it more than once is harmless."""
        if self.file is not None:
            self.file.close()
        # Drop the stale references so the next access reopens the file.
        self.file = None
        self.target = None
        self.label = None

    def __enter__(self):
        self._ensure_open()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()
        return False


In [None]:
# Open the selected HDF5 file as a dataset of (image, points) pairs.
test_dataset = TVDataset(file_path=file_name)

In [None]:
# Release the HDF5 file handle held by test_dataset.
# NOTE(review): the cells below still index test_dataset, so this cell is
# meant to be run only once they are finished — confirm the intended order.
test_dataset.close()

In [None]:
# Preview the first sample: print its label and show the image.
print(f"dataset length = {len(test_dataset)}")

# Read sample 0 from the HDF5 file once instead of once per use.
first_image, first_label = test_dataset[0]
print(f'label = {first_label}')

# Width / height ratio, used to size the horizontal colour bar.
im_ratio = first_image.shape[1]/first_image.shape[0]

# 'tv_process.hdf5' is shown on a 0-1 colour scale, every other file on 0-255.
vmax = 1 if file_name.name == 'tv_process.hdf5' else 255
plt.imshow(first_image, vmin=0, vmax=vmax, cmap = 'plasma')
plt.colorbar(orientation="horizontal",fraction=0.047*im_ratio)
plt.title("target")
plt.show()

In [None]:
# Global image statistics and per-column label statistics.
# `test_dataset[:]` reads every image and every label from the file, so do it
# once and reuse the arrays instead of reloading the file for each statistic.
all_targets, all_labels = test_dataset[:]
target_mean = np.mean(all_targets)
target_std = np.std(all_targets)
label_mean = np.mean(all_labels,axis=0)
label_std = np.std(all_labels,axis=0)
# The full arrays are only needed for the statistics; release the memory.
del all_targets, all_labels

In [None]:
# Show the overall image mean and the label means taken along axis 0.
print(target_mean, label_mean)

In [None]:
# Normalisation transform built from the image statistics.
# NOTE(review): the original call was left unfinished (`mean=` with no value
# and unclosed brackets); it is completed here with target_mean / target_std —
# confirm these are the intended statistics.
transform_data = transforms.Compose([
    transforms.Normalize(mean=[float(target_mean)], std=[float(target_std)]),
])

In [None]:
# Shuffled mini-batches of `batch_size` samples, loaded in the main process.
batch_size = 3
dataloader = DataLoader(
    dataset=test_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=0,
)

In [None]:
class ConvNeuralNet(nn.Module):
    """Small CNN mapping a single-channel image to ``n_outputs`` regression values.

    Two 3x3 convolutions (16 then 32 channels), each followed by ReLU and a
    2x2 max-pool, feed a 128-unit linear layer and a linear output layer.

    Args:
        input_size: ``(height, width)`` of the input images.  Each pooling
            stage halves both sides (rounding down), so the flattened feature
            count is ``32 * (height // 4) * (width // 4)``.  The default
            ``(240, 720)`` gives the original ``32 * 60 * 180``.
        n_outputs: number of values predicted per image (default 4).
    """

    def __init__(self, input_size=(240, 720), n_outputs=4):
        super().__init__()
        height, width = input_size
        self.conv1 = nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)
        # The padded 3x3 convolutions keep the spatial size; only the two
        # pooling stages shrink it.
        flat_features = 32 * (height // 4) * (width // 4)
        self.fc1 = nn.Linear(flat_features, 128)
        self.fc2 = nn.Linear(128, n_outputs)

    def forward(self, x):
        """Return a ``(batch, n_outputs)`` tensor of raw predictions.

        ``x`` may be ``(batch, 1, H, W)`` or a channel-less ``(batch, H, W)``
        batch, for which the channel axis is inserted.
        """
        if x.dim() == 3:
            # (batch, H, W) -> (batch, 1, H, W); conv1 expects one input channel.
            x = x.unsqueeze(1)
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.flatten(1)  # keep the batch axis, flatten everything else
        # NOTE(review): fc1 feeds fc2 with no nonlinearity in between, so the
        # two layers act as one linear map — confirm this is intended.
        x = self.fc1(x)
        x = self.fc2(x)  # raw outputs; no activation is applied to the final layer
        return x
    
# Instantiate the network that the training and evaluation cells use.
model = ConvNeuralNet()

In [None]:
# Training configuration: epoch count, mean-squared-error loss, SGD with momentum.
n_epochs = 1
loss_fn = nn.MSELoss()
optimizer = optim.SGD(
    params=model.parameters(),
    lr=0.001,
    momentum=0.9,
)

In [None]:
# Use the first CUDA device when one is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
model = model.to(device)
print(device)

In [None]:
# Train with gradient accumulation and, on CUDA, mixed precision.
# GradScaler / autocast come from the notebook's import cell.
model.train()
use_amp = device.type == "cuda"  # mixed precision is only used on a CUDA device
scaler = GradScaler(enabled=use_amp)
n_mini = len(dataloader)
accumulation_steps = 10  # change this to fit your GPU memory
for epoch in range(n_epochs):
    running_loss = 0.0
    optimizer.zero_grad()  # reset gradients at the start of each epoch
    for i, data in enumerate(dataloader, 0):

        inputs, labels = data
        inputs = inputs.float().to(device)
        labels = labels.float().to(device)
        if inputs.dim() == 3:
            # (batch, H, W) -> (batch, 1, H, W): the first conv layer expects one channel
            inputs = inputs.unsqueeze(1)

        with autocast(enabled=use_amp):
            outputs = model(inputs)
            loss = loss_fn(outputs, labels)

        # Divide by the number of accumulated mini-batches so the summed
        # gradient is an average rather than `accumulation_steps` times too large.
        scaler.scale(loss / accumulation_steps).backward()

        # Step every `accumulation_steps` mini-batches, and also on the last
        # mini-batch so a trailing partial group is applied instead of being
        # discarded (it is still divided by `accumulation_steps`, so it is
        # weighted slightly lower).
        is_last_batch = (i + 1) == n_mini
        if (i + 1) % accumulation_steps == 0 or is_last_batch:
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()

        # print statistics (the unscaled per-batch loss)
        running_loss += loss.item()
        if i % 20 == 19:    # print every 20 mini-batches
            print('[%d, %5d] loss: %.3f' %
                  (epoch + 1, i + 1, running_loss / 20))
            running_loss = 0.0

print('Finished Training')

In [None]:
# Spot-check the trained model on one shuffled batch.
# NOTE(review): the original call was `TVDataset(target_path, label_path)`,
# but neither name is defined and TVDataset takes a single path; `file_name`
# is used here — confirm this is the file that should be evaluated.
sample_dataset = TVDataset(file_name)
sample_dataloader = DataLoader(sample_dataset, batch_size=batch_size,
                        shuffle=True, num_workers=0)

# set the model to evaluation mode
model.eval()

# initialize the loss accumulator and the counters
test_loss = 0
correct = 0  # kept for compatibility; no accuracy is computed for this regression
total = 0
n_batches = 0

# iterate over the sample dataloader
with torch.no_grad():
    for data in sample_dataloader:
        inputs, labels = data
        inputs = inputs.float().to(device)
        labels = labels.float().to(device)
        if inputs.dim() == 3:
            # (batch, H, W) -> (batch, 1, H, W): the first conv layer expects one channel
            inputs = inputs.unsqueeze(1)
        outputs = model(inputs)
        print(outputs)
        print(labels)
        test_loss += loss_fn(outputs, labels).item()
        print(test_loss)
        total += labels.size(0)
        n_batches += 1
        break  # only the first batch is inspected

# MSELoss already averages within a batch, so average over the batches that
# were actually evaluated rather than over the whole dataset length.
test_loss /= max(n_batches, 1)


In [None]:
# Run the model over every sample and collect one prediction vector per image.
# NOTE(review): the original call was `TVDataset(target_path, label_path)`,
# but neither name is defined and TVDataset takes a single path; `file_name`
# is used here — confirm this is the intended file.
test_dataset = TVDataset(file_name)
model.eval()
predicted = []
with torch.no_grad():
    for i in tqdm(range(len(test_dataset))):
        # The dataset yields NumPy arrays, so convert explicitly before using
        # tensor methods.
        image = torch.as_tensor(np.asarray(test_dataset[i][0], dtype=np.float32))
        # Add batch and channel axes: (..., H, W) -> (1, 1, H, W).
        # assumes one single-channel image per sample — TODO confirm
        batch = image.reshape(1, 1, *image.shape[-2:])
        predicted.append(model(batch.to(device)).cpu().numpy()[0])

### Visualize Evaluation

In [None]:
# Precompute (prediction, ground-truth label) pairs for the animation frames.
# The dataset returns NumPy arrays, which have no `.numpy()` method;
# np.asarray leaves an array unchanged.
data = [(predicted[num], np.asarray(test_dataset[num][1])) for num in tqdm(range(len(predicted)))]

In [None]:
i = 0
fig, ax = plt.subplots()

# Two initially empty scatter artists; update() repositions them for each frame.
scat_pred = ax.scatter([], [], c='lime', label='predicted')
scat_actual = ax.scatter([], [], c='red', label='actual')


def update(num):
    """Place the predicted and actual point pairs of frame ``num`` on the axes."""
    pred_values, actual_values = data[num]
    px1, py1, px2, py2 = pred_values
    tx1, ty1, tx2, ty2 = actual_values
    # One row per point, columns are (x, y).
    scat_pred.set_offsets(np.array([[px1, py1], [px2, py2]]))
    scat_actual.set_offsets(np.array([[tx1, ty1], [tx2, ty2]]))
    return scat_pred, scat_actual


ax.legend()
ax.set_xlim(0, 1)
ax.set_ylim(0, 1)

frame_iter = tqdm(range(len(predicted)))
ani = animation.FuncAnimation(fig, update, frames=frame_iter, interval=30)
HTML(ani.to_jshtml())

In [None]:
# Render the predicted-vs-actual animation, save it as MP4 and show it inline.
fig, ax = plt.subplots()

# Create scatter plots
scat_pred = ax.scatter([], [], c='lime', label='predicted')
scat_actual = ax.scatter([], [], c='red', label='actual')

def update(num):
    """Place the predicted and actual point pairs of frame ``num`` on the axes."""
    pred, actual = data[num]
    x1, y1, x2, y2 = pred
    a1, b1, a2, b2 = actual
    scat_pred.set_offsets(np.c_[[x1, x2], [y1, y2]])
    scat_actual.set_offsets(np.c_[[a1, a2], [b1, b2]])
    return scat_pred, scat_actual

ax.legend()
ax.set_xlim([0,1])
ax.set_ylim([0,1])

ani = animation.FuncAnimation(fig, update, frames=range(len(predicted)), interval=30, blit=True, repeat=False)

# The writer cannot create missing directories, so make sure ./tmp exists.
output_path = Path('./tmp/animation.mp4')
output_path.parent.mkdir(parents=True, exist_ok=True)

pbar = tqdm(total=len(predicted))

def update_progress(current, total):
    """Advance the progress bar by one frame; called by ``ani.save`` per frame."""
    pbar.update()

try:
    ani.save(output_path, writer='ffmpeg', fps=60, progress_callback=update_progress)
finally:
    # Close the bar even when encoding fails.
    pbar.close()

HTML(ani.to_jshtml())