In [None]:
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
from scipy.io import readsav
from mpl_toolkits.axes_grid1 import make_axes_locatable
import pickle

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import ToTensor

In [None]:
def _load_data(filename):
    dat = readsav(filename)
    emission = dat['emission_structure']
    return emission[0]

In [None]:
target_path = Path('tv_images')       # save files read with readsav (inversions + TV video)
label_path = Path('inversion_data')   # pickled point labels; same file stem as each target

# Sort so that `file_idx` selects the same file on every machine: `glob` yields
# entries in filesystem-dependent order, which is not guaranteed to be stable.
target_names = sorted(f for f in target_path.glob('*') if f.is_file())

### Sample Point Data Extraction

In [None]:
# Pick one shot file and one frame index within it to inspect.
file_idx = 0
data_idx = 0
# Fields of emission_structure[0]:
# inverted, radii, elevation, frames, times, vid_frames, vid_times, vid
[inverted, _, _, frames, _, _, _, vid] = _load_data(target_names[file_idx])
# Label pickle fields:
# frame, x_location, l_location, r_location, x_intensity, l_intensity, r_intensity
# Target and label share a file stem. Append the suffix instead of calling
# `with_suffix`, which would drop everything after a dot inside the stem.
pkl_path = label_path / f"{target_names[file_idx].stem}.pkl"
# NOTE(review): pickle.load can execute arbitrary code — only load trusted label files.
with open(pkl_path, 'rb') as pkl_file:
    label_info = pickle.load(pkl_file)
# Integer indices into `vid` for each entry of `inverted` (used as vid[invert_idx[i]]).
invert_idx = frames.astype(int)

In [None]:
l_point = label_info['l_location'][data_idx]
r_point = label_info['r_location'][data_idx]
print(f"left point = {l_point}")
print(f"right point = {r_point}")

# The label points are plotted on the inversion image below, so they appear to be
# pixel coordinates in it; normalise them by the image height.
# NOTE(review): TVDataset hard-codes label_norm = 201 — keep the two in sync, and
# confirm the inversion image is square, since x coordinates are scaled by height too.
label_norm = inverted[0].shape[0]
target_norm = 255  # divisor for video frames (presumably 8-bit pixels — confirm)
print(f"label normalization = {label_norm}")

fig, ax = plt.subplots()
ax.imshow(inverted[data_idx], cmap='plasma')
# Overlay both label points (x = index 0, y = index 1) in a single call.
ax.scatter([l_point[0], r_point[0]], [l_point[1], r_point[1]], c='lime', s=1)
ax.set_title(f"inversion frame {data_idx} with left/right label points")
plt.show()

In [None]:
# Scale the video frame matching this inversion by target_norm, and build the
# flattened label [l_x, l_y, r_x, r_y] scaled by label_norm.
frame_no = invert_idx[data_idx]
target_point = vid[frame_no] / target_norm
point_pair = [label_info['l_location'][data_idx], label_info['r_location'][data_idx]]
label_point = np.array(point_pair).ravel() / label_norm
print(f"label = {label_point}")

# Width/height ratio keeps the horizontal colorbar as wide as the image.
height, width = target_point.shape[:2]
im_ratio = width / height

fig, ax = plt.subplots()
img = ax.imshow(target_point, vmin=0, vmax=1, cmap='plasma')
fig.colorbar(img, ax=ax, orientation="horizontal", fraction=0.047 * im_ratio)
ax.set_title("target")
plt.show()

### Data Loading Implementation

In [None]:
class TVDataset(Dataset):
    """PyTorch Dataset pairing TV video frames with left/right point labels for one file.

    Item ``idx`` is built from ``frames[idx]`` of the save file's
    ``emission_structure[0]`` record. The target is the video frame
    ``vid[int(frames[idx])] / target_norm``. The label is the flattened
    ``[l_location[idx], r_location[idx]] / label_norm``. Both are returned as
    float32 arrays so they match the float32 parameters of a default torch model.

    The save file and the label pickle are read once, in ``__init__``.

    Args:
        target_path: directory of the save files (stored only; ``file_name`` is
            opened directly).
        label_path: directory holding the ``<stem>.pkl`` label files.
        file_name: ``pathlib.Path`` of the save file to load.
        target_norm: divisor applied to video frames (default 255).
        label_norm: divisor applied to point coordinates (default 201).
    """

    # Positions inside emission_structure[0]:
    # inverted, radii, elevation, frames, times, vid_frames, vid_times, vid
    _FRAMES_FIELD = 3
    _VID_FIELD = 7

    def __init__(self, target_path, label_path, file_name, target_norm=255, label_norm=201):
        self.target_path = target_path
        self.label_path = label_path
        self.file_name = file_name
        # Append the suffix rather than using `with_suffix`, so stems containing a dot survive.
        self.pkl_path = self.label_path / f"{self.file_name.stem}.pkl"
        self.target_norm = target_norm
        # NOTE(review): the exploration cell derives this from the inversion image height.
        self.label_norm = label_norm

        # Read the save file once instead of on every __len__/__getitem__ call.
        emission = readsav(self.file_name)['emission_structure'][0]
        self._frames = emission[self._FRAMES_FIELD]
        self._vid = emission[self._VID_FIELD]
        # NOTE(review): pickle.load can execute arbitrary code — only use trusted label files.
        with open(self.pkl_path, 'rb') as pkl_file:
            self._label_info = pickle.load(pkl_file)

    def __len__(self):
        """Number of labelled frames in the file."""
        return len(self._frames)

    def __getitem__(self, idx):
        """Return ``(target, label)`` float32 arrays for frame ``idx``."""
        frame = self._frames[idx]
        target = (self._vid[int(frame)] / self.target_norm).astype(np.float32)
        points = [self._label_info['l_location'][idx], self._label_info['r_location'][idx]]
        label = (np.array(points).ravel() / self.label_norm).astype(np.float32)
        return target, label

In [None]:
# Smoke-test the Dataset on the same file inspected above.
test_dataset = TVDataset(target_path, label_path, target_names[file_idx])
# Index once and reuse the result instead of calling __getitem__ twice.
sample_target, sample_label = test_dataset[0]
print(f'label = {sample_label}')

# Compute the aspect ratio from this sample instead of relying on `im_ratio`
# left over from an earlier cell.
sample_ratio = sample_target.shape[1] / sample_target.shape[0]
fig, ax = plt.subplots()
img = ax.imshow(sample_target, vmin=0, vmax=1, cmap='plasma')
fig.colorbar(img, ax=ax, orientation="horizontal", fraction=0.047 * sample_ratio)
ax.set_title("target (TVDataset item 0)")
plt.show()

In [None]:
# Seed the shuffle so batch order is reproducible under Restart & Run All.
loader_generator = torch.Generator().manual_seed(0)
dataloader = DataLoader(test_dataset, batch_size=4,
                        shuffle=True, num_workers=0,
                        generator=loader_generator)

In [None]:
# Sanity-check the collated (target, label) batch shapes on the first four batches.
for i_batch, sample_batched in zip(range(4), dataloader):
    batch_targets, batch_labels = sample_batched
    print(i_batch, batch_targets.size(), batch_labels.size())

In [None]:
class ConvNeuralNet(nn.Module):
    """Small CNN: two conv-conv-maxpool stages followed by one linear output layer.

    Args:
        num_classes: width of the final linear layer (4 for the flattened
            ``[l_x, l_y, r_x, r_y]`` point label).
        in_channels: channels of the input image (default 3).
            NOTE(review): TVDataset frames appear to be single-channel 2-D
            arrays — confirm and pass ``in_channels=1`` with a channel axis added.
        input_size: ``(height, width)`` of the input, used to size ``fc1``. The
            default ``(32, 32)`` gives 64 * 5 * 5 = 1600 features, as before.
    """

    def __init__(self, num_classes, in_channels=3, input_size=(32, 32)):
        super().__init__()
        self.conv_layer1 = nn.Conv2d(in_channels=in_channels, out_channels=32, kernel_size=3)
        self.conv_layer2 = nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3)
        self.max_pool1 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.conv_layer3 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3)
        self.conv_layer4 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3)
        self.max_pool2 = nn.MaxPool2d(kernel_size=2, stride=2)

        # Size fc1 from a dummy forward pass instead of a magic constant, so it
        # matches any input_size. Registered last, so default-init RNG order is unchanged.
        with torch.no_grad():
            n_flat = self._features(torch.zeros(1, in_channels, *input_size)).shape[1]
        self.fc1 = nn.Linear(n_flat, num_classes)
        # NOTE(review): there is no activation (e.g. ReLU) between layers, so apart
        # from max-pooling the network is linear — confirm this is intended.

    def _features(self, x):
        """Run the conv/pool stages and flatten to ``(batch, n_features)``."""
        out = self.conv_layer1(x)
        out = self.conv_layer2(out)
        out = self.max_pool1(out)

        out = self.conv_layer3(out)
        out = self.conv_layer4(out)
        out = self.max_pool2(out)

        return torch.flatten(out, 1)

    def forward(self, x):
        """Map a ``(batch, in_channels, H, W)`` tensor to ``(batch, num_classes)``."""
        return self.fc1(self._features(x))

In [None]:
# Fix the torch seed before building the model so initial weights are reproducible.
torch.manual_seed(0)
model = ConvNeuralNet(num_classes=4)  # 4 outputs: flattened [l_x, l_y, r_x, r_y] label
# NOTE(review): the model's first conv layer is built with in_channels=3 here, but
# TVDataset items appear to be single 2-D frames — add a channel axis and match
# in_channels (and check fc1's input size) before training. TODO confirm.
n_epochs = 10
loss_fn = nn.MSELoss()  # regression onto the normalised point coordinates
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
model.train()