In [None]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from IPython.display import HTML
from pathlib import Path
from scipy.io import readsav
from mpl_toolkits.axes_grid1 import make_axes_locatable
import pickle

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import ToTensor

In [None]:
def _load_data(filename):
    dat = readsav(filename)
    emission = dat['emission_structure']
    return emission[0]

In [None]:
# Directories holding the input videos (targets) and the fitted point labels.
target_path = Path('tv_images')
label_path = Path('inversion_data')

# Path.glob yields entries in arbitrary, filesystem-dependent order; sort so
# that target_names[file_idx] refers to the same file on every run/machine.
target_names = sorted(f for f in target_path.glob('*') if f.is_file())

### Sample Point Data Extraction

In [None]:
file_idx = 0
data_idx = 0
# Fields of the emission structure, in order:
# inverted, radii, elevation, frames, times, vid_frames, vid_times, vid
# Index the fields that are needed instead of unpacking into `_`, which would
# overwrite IPython's "last output" variable.
emission = _load_data(target_names[file_idx])
inverted, frames, vid = emission[0], emission[3], emission[7]
# Keys of the label dictionary:
# frame, x_location, l_location, r_location, x_intensity, l_intensity, r_intensity
pkl_path = (label_path / target_names[file_idx].stem).with_suffix('.pkl') # target and label have same name stem
# NOTE(review): pickle.load can execute arbitrary code; only load label files
# from a trusted source.
with open(pkl_path, 'rb') as pkl_file:  # closed even if unpickling fails
    label_info = pickle.load(pkl_file)
invert_idx = frames.astype(int)

In [None]:
# Left / right label points for the selected sample.
l_point, r_point = (
    label_info[side][data_idx] for side in ('l_location', 'r_location')
)
print(f"left point = {l_point}")
print(f"right point = {r_point}")

# Normalization constants: video pixels are divided by 255, label coordinates
# by the number of rows of an inverted frame.
target_norm = 255
label_norm = inverted[0].shape[0]
print(f"label normalization = {label_norm}")

# Overlay both label points on the inverted frame.
plt.imshow(inverted[data_idx], cmap='plasma')
for point in (l_point, r_point):
    plt.scatter(point[0], point[1], c='lime', s=1)
plt.show()

In [None]:
# Video frame matching the selected inversion, scaled by target_norm.
video_frame = vid[invert_idx[data_idx]]
target_point = video_frame / target_norm

# Flattened [left, right] label coordinates, scaled by label_norm.
raw_label = np.array([label_info['l_location'][data_idx],
                      label_info['r_location'][data_idx]])
label_point = raw_label.reshape(-1) / label_norm
print(f"label = {label_point}")

# Width / height ratio, used to size the horizontal colorbar.
n_rows, n_cols = target_point.shape[0], target_point.shape[1]
im_ratio = n_cols / n_rows
plt.imshow(target_point, cmap='plasma', vmin=0, vmax=1)
plt.colorbar(fraction=0.047 * im_ratio, orientation="horizontal")
plt.title("target")
plt.show()

### Neural Network Implementation

In [None]:
class TVDataset(Dataset):
    """Dataset pairing video frames of one save file with its point labels.

    Each item is ``(target, label)``:

    * ``target`` - the video frame (field 7 of the emission structure) selected
      by the ``idx``-th entry of the frame list (field 3), divided by
      ``target_norm`` and given a leading axis of length 1.
    * ``label`` - the ``idx``-th ``'l_location'`` and ``'r_location'`` entries
      of the pickled label dictionary, flattened and divided by ``label_norm``.

    The save file and the pickle are read lazily on first access and then kept
    in memory; previously the save file was re-parsed on every ``len()`` call
    and twice per item, and the pickle was reloaded per item.

    Args:
        target_path: Directory of the video files (stored, not otherwise used).
        label_path: Directory containing the ``.pkl`` label files.
        file_name: ``Path`` of the save file; its stem names the label file.
        target_norm: Divisor applied to the video frame (default 255).
        label_norm: Divisor applied to the label coordinates (default 201).
    """

    def __init__(self, target_path, label_path, file_name,
                 target_norm=255, label_norm=201):
        self.target_path = target_path
        self.label_path = label_path
        self.file_name = file_name
        self.pkl_path = (self.label_path / self.file_name.stem).with_suffix('.pkl')
        self.target_norm = target_norm
        self.label_norm = label_norm
        # Caches filled on first use so that construction performs no I/O.
        self._frames = None
        self._video = None
        self._label_info = None

    def _load_emission(self):
        """Parse the save file once and cache the frame list and the video."""
        if self._frames is None or self._video is None:
            emission = readsav(self.file_name)['emission_structure'][0]
            self._frames = emission[3]
            self._video = emission[7]

    def _load_labels(self):
        """Unpickle the label dictionary once and cache it."""
        if self._label_info is None:
            # NOTE(review): pickle.load can execute arbitrary code; only use
            # label files from a trusted source.
            with open(self.pkl_path, 'rb') as pkl_file:
                self._label_info = pickle.load(pkl_file)

    def __len__(self):
        self._load_emission()
        return len(self._frames) # gets length of frames

    def __getitem__(self, idx):
        self._load_emission()
        self._load_labels()
        frame = self._frames[idx]
        # Division and np.array([...]) both allocate new arrays, so the
        # returned tensors never alias the cached data.
        target = self._video[int(frame)] / self.target_norm
        target = np.array([target])
        label = np.array([self._label_info['l_location'][idx],
                          self._label_info['r_location'][idx]]).ravel() / self.label_norm
        target_tensor = torch.from_numpy(target)
        label_tensor = torch.from_numpy(label)
        return target_tensor, label_tensor

**TODO:** extend this to all videos; the cells below only use a single video file.

In [None]:
test_dataset = TVDataset(target_path, label_path, target_names[file_idx])
# Fetch the first sample once and reuse it for both the printout and the plot.
sample_target, sample_label = test_dataset[0]
print(f'label = {sample_label}')

# Derive the aspect ratio from the sample itself (shape is (1, rows, cols))
# rather than relying on a variable left over from an earlier cell.
sample_ratio = sample_target.shape[2] / sample_target.shape[1]
plt.imshow(sample_target[0], vmin=0, vmax=1, cmap = 'plasma')
plt.colorbar(orientation="horizontal",fraction=0.047*sample_ratio)
plt.title("target")
plt.show()

In [None]:
# Shuffled mini-batches of three samples, loaded in the main process.
batch_size = 3
dataloader = DataLoader(
    dataset=test_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=0,
)

In [None]:
# Print the input/label batch shapes of the first four mini-batches.
for i_batch, sample_batched in enumerate(dataloader):
    batch_inputs, batch_labels = sample_batched[0], sample_batched[1]
    print(i_batch, batch_inputs.size(), batch_labels.size())
    if i_batch == 3:
        break

In [None]:
class ConvNeuralNet(nn.Module):
    """Small CNN regressing four values from a single-channel image.

    Two ``3x3 convolution -> ReLU -> 2x2 max-pool`` stages (1 -> 16 -> 32
    channels) are followed by two fully connected layers
    (flattened features -> 128 -> 4).

    Args:
        input_size: ``(height, width)`` of the input images; it fixes the
            number of inputs of the first fully connected layer. The default
            ``(240, 720)`` yields the previously hard-coded ``32 * 60 * 180``.
    """

    def __init__(self, input_size=(240, 720)):
        super().__init__()
        height, width = input_size
        self.conv1 = nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)
        # The padded 3x3 convolutions keep the spatial size; each of the two
        # 2x2 poolings halves it (rounding down).
        pooled_h = height // 2 // 2
        pooled_w = width // 2 // 2
        self.fc1 = nn.Linear(32 * pooled_h * pooled_w, 128)
        self.fc2 = nn.Linear(128, 4)

    def forward(self, x):
        """Map a ``(batch, 1, height, width)`` tensor to ``(batch, 4)`` outputs."""
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = torch.flatten(x, start_dim=1)  # Flatten everything but the batch axis
        # Without a non-linearity here fc1 and fc2 would collapse into a
        # single affine map, making the 128-unit hidden layer pointless.
        x = F.relu(self.fc1(x))
        # Linear output head: no sigmoid (or other activation) is applied.
        x = self.fc2(x)
        return x
    
# Seed PyTorch's global RNG before the weights are initialised so that the
# initialisation (and the DataLoader shuffling that follows) is reproducible
# after a kernel restart.
SEED = 0
torch.manual_seed(SEED)
model = ConvNeuralNet()

In [None]:
# Training setup: number of passes over the data, mean-squared-error loss,
# and SGD with momentum.
n_epochs = 5
loss_fn = nn.MSELoss()
optimizer = optim.SGD(
    model.parameters(),
    lr=1e-3,
    momentum=0.9,
)

In [None]:
# Use the first CUDA GPU when one is available; otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
model = model.to(device)
print(device)

In [None]:
# Number of mini-batches between two progress reports.
LOG_EVERY = 20

model.train()
n_mini = len(dataloader)
for epoch in range(n_epochs):
    running_loss = 0.0  # loss accumulated since the last progress report
    epoch_loss = 0.0    # loss accumulated over the whole epoch
    for i, data in enumerate(dataloader, 0):

        inputs, labels = data
        # The dataset's tensors are cast to float32 to match the model weights.
        inputs = inputs.float().to(device)
        labels = labels.float().to(device)

        optimizer.zero_grad()

        outputs = model(inputs)
        loss = loss_fn(outputs, labels)
        loss.backward()
        optimizer.step()

        # print statistics
        batch_loss = loss.item()
        running_loss += batch_loss
        epoch_loss += batch_loss
        if i % LOG_EVERY == LOG_EVERY - 1:    # print every LOG_EVERY mini-batches
            print('[%d, %5d] loss: %.3f' %
                  (epoch + 1, i + 1, running_loss / LOG_EVERY))
            running_loss = 0.0

    # Always report once per epoch so that runs with fewer than LOG_EVERY
    # mini-batches still produce output.
    if n_mini:
        print('[%d] mean epoch loss: %.3f' % (epoch + 1, epoch_loss / n_mini))

print('Finished Training')

In [None]:
# Sanity check on the same file the model was trained on (there is no held-out
# split yet, so this is not a generalisation score).
max_eval_batches = 1  # number of mini-batches to inspect; None = whole dataset

# create a new dataloader with the sample dataset
sample_dataset = TVDataset(target_path, label_path, target_names[file_idx])
sample_dataloader = DataLoader(sample_dataset, batch_size=batch_size,
                        shuffle=True, num_workers=0)

# set the model to evaluation mode
model.eval()

# sample-weighted loss sum and number of samples evaluated
test_loss = 0.0
total = 0

# iterate over the sample dataloader
with torch.no_grad():
    for batch_idx, data in enumerate(sample_dataloader):
        if max_eval_batches is not None and batch_idx >= max_eval_batches:
            break
        inputs, labels = data
        inputs = inputs.float().to(device)
        labels = labels.float().to(device)
        outputs = model(inputs)
        print(outputs)
        print(labels)
        batch_loss = loss_fn(outputs, labels).item()
        print(batch_loss)
        # loss_fn returns a per-batch mean; weight it by the batch size so the
        # final value is a mean over samples even when the last batch is short.
        test_loss += batch_loss * labels.size(0)
        total += labels.size(0)

# Normalise by the number of samples actually evaluated. Dividing by the full
# dataset size would be wrong whenever only some of the batches are processed.
if total:
    test_loss /= total
print(f'mean loss over {total} evaluated samples = {test_loss}')


In [None]:
test_dataset = TVDataset(target_path, label_path, target_names[file_idx])

# Do not rely on an earlier cell having switched the model to evaluation mode.
model.eval()

# Model output for every sample of the dataset, in dataset order.
predicted = []
with torch.no_grad():
    for i in range(len(test_dataset)):
        data = test_dataset[i][0].unsqueeze(0)  # add the batch dimension
        predicted.append(model(data.float().to(device)).cpu().numpy()[0])

### Visualize Evaluation

In [None]:
fig, ax = plt.subplots()

# Look the ground-truth labels up once, rather than indexing the dataset again
# every time a frame is drawn.
actual = [test_dataset[k][1].numpy() for k in range(len(predicted))]

def update(num):
    """Redraw frame ``num``: predicted point pair (lime) vs. actual pair (red)."""
    ax.clear()
    x1, y1, x2, y2 = predicted[num]
    a1, b1, a2, b2 = actual[num]
    ax.scatter(x1, y1, c='lime', label= 'predicted')
    ax.scatter(x2, y2, c='lime')
    ax.scatter(a1, b1, c='red', label='actual')
    ax.scatter(a2, b2, c='red')
    ax.legend()
    # Labels are stored in normalised coordinates, so fix the axes to [0, 1].
    ax.set_xlim([0,1])
    ax.set_ylim([0,1])

ani = animation.FuncAnimation(fig, update, frames=range(len(predicted)), interval=30)
ani_html = ani.to_jshtml()
plt.close(fig)  # keep the static figure from being rendered next to the animation
HTML(ani_html)