In [1]:
from imports.ExtractContactCases import ExtractContactCases, cases_dict

# Load the extracted contact samples. Below, `samples` is used as a dict keyed by
# sample id (e.g. 'sample_100'); each entry provides 'events' and 'case'.
ex = ExtractContactCases('contact_extraction1')
samples = ex.load()

In [2]:
import numpy as np
import torch

import torch_geometric.transforms as T
from torch_geometric.nn.pool import radius_graph, knn_graph
from torch_geometric.data import Data
from torch_geometric import seed_everything

# Seed Python's `random`, NumPy and PyTorch RNGs (PyG helper) so weight
# initialisation and DataLoader shuffling are reproducible.
seed_everything(0)

# Image size in pixels (presumably the event sensor's resolution — confirm);
# not referenced elsewhere in the visible notebook.
im_height=260
im_width=346

def make_graph(events, case, center=(157, 122), k=16, t_scale=5e-6):
    """Build a PyG ``Data`` graph from one sample's event array.

    Each event becomes a node. Pixel coordinates (columns 0-1) are shifted by
    ``center``. Timestamps (column 2) are taken relative to the first event and
    multiplied by ``t_scale``. Every node is connected to its ``k`` nearest
    neighbours in (dx, dy, t) space.

    Args:
        events: 2-D array with at least 3 columns: x, y, timestamp. At least
            one row is required. Rows are presumably sorted by time — TODO confirm.
        case: Key into ``cases_dict``; its value becomes the regression target.
        center: Pixel coordinate subtracted from columns 0-1. The default
            reproduces the previously hard-coded (157, 122).
        k: Neighbours per node passed to ``knn_graph``.
        t_scale: Multiplier for relative timestamps. It is presumably chosen
            to make time comparable to pixel distances in the kNN search.

    Returns:
        ``Data`` with:
          - ``x``: (N, 3) node features [r, theta, t]. Each column is divided
            by its largest absolute value; all-zero columns stay unscaled.
          - ``pos``: (N, 3) positions [dx, dy, t] used to build the kNN graph.
          - ``edge_index``: kNN edges.
          - ``edge_attr``: (E,) absolute radius difference between edge endpoints.
          - ``y``: (1, D) float32 target built from ``cases_dict[case]``.
    """
    offsets = (events[:, 0:2] - np.asarray(center)).astype(np.float32)
    coords = torch.tensor(offsets)

    rel_t = (events[:, 2] - events[0, 2]).astype(np.float32) * t_scale
    # Force float32 so t always matches the dtype of coords.
    t = torch.tensor(rel_t.astype(np.float32, copy=False))

    pos = torch.stack((coords[:, 0], coords[:, 1], t), dim=1)

    # Polar coordinates of each event around `center`.
    r = torch.norm(coords, dim=-1)
    theta = torch.atan2(coords[:, 1], coords[:, 0])

    feature = torch.stack((r, theta, t), dim=1)
    # Scale each column by its largest magnitude so values lie in [-1, 1].
    # Dividing by the signed max would flip theta's sign whenever every angle
    # is negative. A zero denominator (e.g. a single event) would yield NaN/inf.
    denom = feature.abs().max(dim=0).values
    denom = torch.where(denom > 0, denom, torch.ones_like(denom))
    feature = feature / denom

    edge_index = knn_graph(pos, k)
    row, col = edge_index
    edge_attr = torch.abs(r[row] - r[col])

    y = torch.tensor(np.array(cases_dict[case], dtype=np.float32)).reshape(1, -1)

    return Data(x=feature, edge_index=edge_index, pos=pos, y=y, edge_attr=edge_attr)


In [3]:
class MaxPooling(torch.nn.Module):
    """Voxel-grid max pooling of a graph using the first two columns of ``pos``.

    Nodes get a cluster id from PyG's ``voxel_grid`` (computed on ``pos[:, :2]``
    with voxel size ``size``). The graph is then coarsened with PyG's ``max_pool``.

    Args:
        size: Voxel size passed to ``voxel_grid``.
        transform: Optional transform handed to ``max_pool`` (e.g. to recompute
            edge attributes of the pooled graph).
    """

    def __init__(self, size, transform=None):
        super().__init__()
        self.voxel_size = size
        self.transform = transform

    def forward(self, x, pos, batch=None, edge_index=None, return_data_obj=True):
        """Pool the graph given by (x, pos, batch, edge_index).

        Returns:
            The pooled ``Data`` object if ``return_data_obj`` is true. Otherwise
            the tuple ``(x, pos, batch, edge_index, edge_attr)`` of the pooled graph.

        Raises:
            AssertionError: if ``edge_index`` is None.
        """
        assert edge_index is not None, "edge_index must not be None"

        # These were never imported anywhere in the notebook, so forward()
        # raised NameError. Import them locally to keep this cell self-contained.
        from torch_geometric.nn.pool import max_pool, voxel_grid

        cluster = voxel_grid(pos[:, :2], batch=batch, size=self.voxel_size)
        data = Data(x=x, pos=pos, edge_index=edge_index, batch=batch)
        data = max_pool(cluster, data=data, transform=self.transform)  # transform for new edge attributes
        if return_data_obj:
            return data
        return data.x, data.pos, data.batch, data.edge_index, data.edge_attr

    def __repr__(self):
        return f"{self.__class__.__name__}(voxel_size={self.voxel_size})"

In [4]:
from torch_geometric.nn import GCNConv, global_mean_pool
from torch.nn.functional import elu, dropout

class Model(torch.nn.Module):
    """Three-layer GCN that maps a graph (3 node features) to a 2-value output.

    Pipeline: GCNConv 3→64→128→256 (ELU after each, using ``edge_attr`` as the
    edge weights), then a per-graph mean pool, then Linear 256→64 + ELU +
    dropout(0.5), then Linear 64→2.
    """

    def __init__(self):
        super().__init__()
        # Construction order is kept identical so seeded initialisation matches.
        self.conv1 = GCNConv(3, 64)
        # NOTE(review): bn1 is registered but never used in forward(); kept so
        # the parameter / state_dict layout is unchanged.
        self.bn1 = torch.nn.BatchNorm1d(64)
        self.conv2 = GCNConv(64, 128)
        self.conv3 = GCNConv(128, 256)
        self.fc1 = torch.nn.Linear(256, 64)
        self.fc2 = torch.nn.Linear(64, 2)

    def forward(self, data):
        """Return a (num_graphs, 2) prediction for the batch in ``data``."""
        h = data.x
        for layer in (self.conv1, self.conv2, self.conv3):
            h = elu(layer(h, data.edge_index, edge_weight=data.edge_attr))

        graph_emb = global_mean_pool(h, data.batch)  # [batch_size, 256]
        hidden = dropout(elu(self.fc1(graph_emb)), p=0.5, training=self.training)
        return self.fc2(hidden)

In [5]:
# Instantiate the GNN and move its parameters to the GPU. The training and
# evaluation code also moves batches to 'cuda', so a CUDA device is required.
model = Model().cuda()

In [6]:
from sklearn.model_selection import train_test_split


# 60/20/20 train/val/test split of sample ids, stratified by contact case.
# random_state=0 keeps the split fixed across extractions.
sample_idx = list(samples.keys())
cases = [str(samples[s_idx]['case']) for s_idx in sample_idx]

train_idx, val_test_idx = train_test_split(sample_idx, test_size=1-0.6, random_state=0, stratify=cases) #fixed across extractions

# Halve the held-out 40% into val/test, re-stratifying on the held-out cases.
cases = [str(samples[s_idx]['case']) for s_idx in val_test_idx]
val_idx, test_idx = train_test_split(val_test_idx,  test_size=0.5, random_state=0, stratify=cases) #fixed across extractions


        
def _graphs_for(indices):
    """Lazily build one graph (via make_graph) per sample id in ``indices``, in order."""
    for idx in indices:
        entry = samples[idx]
        yield make_graph(np.array(entry['events']), entry['case'])


def val_gen():
    """Yield the validation graphs, one per id in ``val_idx``."""
    yield from _graphs_for(val_idx)


def test_gen():
    """Yield the test graphs, one per id in ``test_idx``."""
    yield from _graphs_for(test_idx)
        



In [7]:
# Scratch check: time span between the first and last event of one sample.
# It is multiplied by 1e-9 (timestamps presumably in ns — TODO confirm) and by
# 0.85 (meaning of this factor is not documented here).
ev_arr = np.array(samples['sample_100']['events'])
period = ev_arr[-1, 2] - ev_arr[0, 2]

period*1e-9*0.85

0.0425663

In [8]:
def time_scale(sample_idx, scales=(0.9, 0.95, 1, 1.05, 1.1)):
    """Yield one graph per temporal scaling factor for a single sample.

    Time-stretch augmentation: the timestamp column (index 2) of the sample's
    events is multiplied by each factor before the graph is built with
    ``make_graph``. A factor of exactly 1 uses the unmodified events.

    Args:
        sample_idx: Key into the global ``samples`` dict.
        scales: Multiplicative factors for the timestamp column, applied in order.

    Yields:
        The ``Data`` object returned by ``make_graph`` for each factor.
    """
    # Convert the stored events and look up the case once, not once per factor.
    base_events = np.array(samples[sample_idx]['events'])
    case = samples[sample_idx]['case']
    for scale in scales:
        if scale == 1:
            events = base_events
        else:
            # Scale a copy so base_events stays pristine for the next factor.
            # In-place assignment keeps base_events' dtype, matching the
            # original behaviour.
            events = base_events.copy()
            events[:, 2] = scale * events[:, 2]
        yield make_graph(events, case)

def train_gen():
    """Yield every time-scaled training graph, walking ``train_idx`` in order."""
    for idx in train_idx:
        yield from time_scale(idx)
            
from torch_geometric.loader import DataLoader

# Materialise all augmented training graphs up front. Mini-batches of 5 graphs
# are reshuffled every time the loader is iterated (i.e. each epoch).
loader = DataLoader(list(train_gen()), batch_size=5, shuffle=True)

In [9]:
# Sanity check: build the first augmented training graph and inspect its
# normalised node features.
d = next(train_gen())
d.x

tensor([[1.0000, 0.5121, 0.0000],
        [1.0000, 0.5121, 0.0032],
        [1.0000, 0.5121, 0.0064],
        ...,
        [0.2252, 0.2120, 0.9987],
        [0.9046, 0.3863, 0.9992],
        [0.6603, 0.8879, 1.0000]])

In [10]:
# Mean absolute error between predictions and targets. The training loop also
# reports this loss multiplied by 180/pi as "degrees", so targets are presumably
# angles in radians — TODO confirm.
loss_func = torch.nn.L1Loss()
n_epochs = 150
# Adam with L2 weight decay; the learning rate is stepped down in the training loop.
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=0.0005)
# (disabled) plateau-based alternative to the manual lr steps:
#scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', min_lr=1e-6, patience=25)

def validate():
    loss = 0
    for i, data in enumerate(val_gen()):      
        data = data.to('cuda')
        end_point = model(data)

        loss += loss_func(end_point, data.y).detach().item()
    loss /= len(val_idx)
    return loss

def test():
    loss = 0
    for i, data in enumerate(test_gen()):      
        data = data.to('cuda')
        end_point = model(data)

        loss += loss_func(end_point, data.y).detach().item()
    loss /= len(val_idx)
    return loss

In [11]:
from tqdm.auto import trange, tqdm
train_losses = []  # mean training loss per epoch
val_losses = []    # validation loss per epoch, from validate()
lrs = []           # learning rate in effect during each epoch

# Step-down learning-rate schedule: {epoch: lr set at the start of that epoch}.
# Epoch 0 restores the optimizer's constructor lr. This lets a re-run of the
# cell start from the same schedule instead of the last lr left in param_groups.
LR_SCHEDULE = {0: optimizer.defaults['lr'], 20: 1e-3, 80: 1e-4, 100: 1e-5}

for epoch in trange(n_epochs, desc='training', unit='epoch'):
    if epoch in LR_SCHEDULE:
        for param_group in optimizer.param_groups:
            param_group['lr'] = LR_SCHEDULE[epoch]

    # Explicitly (re-)enter training mode so dropout is active, whatever mode
    # evaluation code left the model in.
    model.train()
    lr = optimizer.param_groups[0]['lr']
    lrs.append(lr)
    # Average over the real number of batches. len(train_idx) only matched
    # this by coincidence (augmentations per sample == batch_size).
    n_batches = len(loader)
    epoch_loss = 0.0

    # No torch.autograd.detect_anomaly() here: it is a debugging aid that
    # slows every step considerably and emits a warning.
    with tqdm(loader, unit="batch", total=n_batches, desc=f"Epoch {epoch}") as tepoch:
        for i, data in enumerate(tepoch):
            data = data.to('cuda')
            optimizer.zero_grad()
            end_point = model(data)
            loss = loss_func(end_point, data.y)
            loss.backward()
            optimizer.step()

            epoch_loss += loss.detach().item()
            running_loss = epoch_loss / (i + 1)
            prev_val = val_losses[-1] if val_losses else None
            tepoch.set_postfix({
                'train_loss': running_loss,
                'train_loss_degrees': running_loss * 180/np.pi,
                'val_loss': prev_val if prev_val is not None else 'na',
                'val_loss_degrees': prev_val * 180/np.pi if prev_val is not None else 'na',
                'lr': lr,
                })

        epoch_loss /= n_batches
        val_loss = validate()
        tepoch.set_postfix({'train_loss': epoch_loss, 'val_loss': val_loss})
        train_losses.append(epoch_loss)
        val_losses.append(val_loss)
        


training:   0%|          | 0/150 [00:00<?, ?epoch/s]

  0%|          | 0/386 [00:00<?, ?batch/s]

  with torch.autograd.detect_anomaly():


  0%|          | 0/386 [00:00<?, ?batch/s]

  0%|          | 0/386 [00:00<?, ?batch/s]

  0%|          | 0/386 [00:00<?, ?batch/s]

  0%|          | 0/386 [00:00<?, ?batch/s]

  0%|          | 0/386 [00:00<?, ?batch/s]

  0%|          | 0/386 [00:00<?, ?batch/s]

  0%|          | 0/386 [00:00<?, ?batch/s]

  0%|          | 0/386 [00:00<?, ?batch/s]

  0%|          | 0/386 [00:00<?, ?batch/s]

  0%|          | 0/386 [00:00<?, ?batch/s]

  0%|          | 0/386 [00:00<?, ?batch/s]

  0%|          | 0/386 [00:00<?, ?batch/s]

  0%|          | 0/386 [00:00<?, ?batch/s]

  0%|          | 0/386 [00:00<?, ?batch/s]

  0%|          | 0/386 [00:00<?, ?batch/s]

  0%|          | 0/386 [00:00<?, ?batch/s]

  0%|          | 0/386 [00:00<?, ?batch/s]

  0%|          | 0/386 [00:00<?, ?batch/s]

  0%|          | 0/386 [00:00<?, ?batch/s]

  0%|          | 0/386 [00:00<?, ?batch/s]

  0%|          | 0/386 [00:00<?, ?batch/s]

  0%|          | 0/386 [00:00<?, ?batch/s]

  0%|          | 0/386 [00:00<?, ?batch/s]

  0%|          | 0/386 [00:00<?, ?batch/s]

  0%|          | 0/386 [00:00<?, ?batch/s]

  0%|          | 0/386 [00:00<?, ?batch/s]

  0%|          | 0/386 [00:00<?, ?batch/s]

  0%|          | 0/386 [00:00<?, ?batch/s]

  0%|          | 0/386 [00:00<?, ?batch/s]

  0%|          | 0/386 [00:00<?, ?batch/s]

  0%|          | 0/386 [00:00<?, ?batch/s]

  0%|          | 0/386 [00:00<?, ?batch/s]

  0%|          | 0/386 [00:00<?, ?batch/s]

  0%|          | 0/386 [00:00<?, ?batch/s]

  0%|          | 0/386 [00:00<?, ?batch/s]

  0%|          | 0/386 [00:00<?, ?batch/s]

  0%|          | 0/386 [00:00<?, ?batch/s]

  0%|          | 0/386 [00:00<?, ?batch/s]

  0%|          | 0/386 [00:00<?, ?batch/s]

  0%|          | 0/386 [00:00<?, ?batch/s]

  0%|          | 0/386 [00:00<?, ?batch/s]

  0%|          | 0/386 [00:00<?, ?batch/s]

  0%|          | 0/386 [00:00<?, ?batch/s]

  0%|          | 0/386 [00:00<?, ?batch/s]

  0%|          | 0/386 [00:00<?, ?batch/s]

  0%|          | 0/386 [00:00<?, ?batch/s]

  0%|          | 0/386 [00:00<?, ?batch/s]

  0%|          | 0/386 [00:00<?, ?batch/s]

  0%|          | 0/386 [00:00<?, ?batch/s]

  0%|          | 0/386 [00:00<?, ?batch/s]

  0%|          | 0/386 [00:00<?, ?batch/s]

  0%|          | 0/386 [00:00<?, ?batch/s]

  0%|          | 0/386 [00:00<?, ?batch/s]

  0%|          | 0/386 [00:00<?, ?batch/s]

  0%|          | 0/386 [00:00<?, ?batch/s]

  0%|          | 0/386 [00:00<?, ?batch/s]

  0%|          | 0/386 [00:00<?, ?batch/s]

  0%|          | 0/386 [00:00<?, ?batch/s]

  0%|          | 0/386 [00:00<?, ?batch/s]

  0%|          | 0/386 [00:00<?, ?batch/s]

  0%|          | 0/386 [00:00<?, ?batch/s]

  0%|          | 0/386 [00:00<?, ?batch/s]

  0%|          | 0/386 [00:00<?, ?batch/s]

  0%|          | 0/386 [00:00<?, ?batch/s]

  0%|          | 0/386 [00:00<?, ?batch/s]

  0%|          | 0/386 [00:00<?, ?batch/s]

  0%|          | 0/386 [00:00<?, ?batch/s]

  0%|          | 0/386 [00:00<?, ?batch/s]

  0%|          | 0/386 [00:00<?, ?batch/s]

  0%|          | 0/386 [00:00<?, ?batch/s]

  0%|          | 0/386 [00:00<?, ?batch/s]

  0%|          | 0/386 [00:00<?, ?batch/s]

  0%|          | 0/386 [00:00<?, ?batch/s]

  0%|          | 0/386 [00:00<?, ?batch/s]

  0%|          | 0/386 [00:00<?, ?batch/s]

  0%|          | 0/386 [00:00<?, ?batch/s]

  0%|          | 0/386 [00:00<?, ?batch/s]

  0%|          | 0/386 [00:00<?, ?batch/s]

  0%|          | 0/386 [00:00<?, ?batch/s]

  0%|          | 0/386 [00:00<?, ?batch/s]

  0%|          | 0/386 [00:00<?, ?batch/s]

  0%|          | 0/386 [00:00<?, ?batch/s]

  0%|          | 0/386 [00:00<?, ?batch/s]

  0%|          | 0/386 [00:00<?, ?batch/s]

  0%|          | 0/386 [00:00<?, ?batch/s]

  0%|          | 0/386 [00:00<?, ?batch/s]

  0%|          | 0/386 [00:00<?, ?batch/s]

  0%|          | 0/386 [00:00<?, ?batch/s]

  0%|          | 0/386 [00:00<?, ?batch/s]

  0%|          | 0/386 [00:00<?, ?batch/s]

  0%|          | 0/386 [00:00<?, ?batch/s]

  0%|          | 0/386 [00:00<?, ?batch/s]

  0%|          | 0/386 [00:00<?, ?batch/s]

  0%|          | 0/386 [00:00<?, ?batch/s]

  0%|          | 0/386 [00:00<?, ?batch/s]

  0%|          | 0/386 [00:00<?, ?batch/s]

  0%|          | 0/386 [00:00<?, ?batch/s]

  0%|          | 0/386 [00:00<?, ?batch/s]

  0%|          | 0/386 [00:00<?, ?batch/s]

  0%|          | 0/386 [00:00<?, ?batch/s]

  0%|          | 0/386 [00:00<?, ?batch/s]

  0%|          | 0/386 [00:00<?, ?batch/s]

  0%|          | 0/386 [00:00<?, ?batch/s]

  0%|          | 0/386 [00:00<?, ?batch/s]

  0%|          | 0/386 [00:00<?, ?batch/s]

  0%|          | 0/386 [00:00<?, ?batch/s]

  0%|          | 0/386 [00:00<?, ?batch/s]

  0%|          | 0/386 [00:00<?, ?batch/s]

  0%|          | 0/386 [00:00<?, ?batch/s]

  0%|          | 0/386 [00:00<?, ?batch/s]

  0%|          | 0/386 [00:00<?, ?batch/s]

  0%|          | 0/386 [00:00<?, ?batch/s]

  0%|          | 0/386 [00:00<?, ?batch/s]

  0%|          | 0/386 [00:00<?, ?batch/s]

  0%|          | 0/386 [00:00<?, ?batch/s]

  0%|          | 0/386 [00:00<?, ?batch/s]

  0%|          | 0/386 [00:00<?, ?batch/s]

  0%|          | 0/386 [00:00<?, ?batch/s]

  0%|          | 0/386 [00:00<?, ?batch/s]

  0%|          | 0/386 [00:00<?, ?batch/s]

  0%|          | 0/386 [00:00<?, ?batch/s]

  0%|          | 0/386 [00:00<?, ?batch/s]

  0%|          | 0/386 [00:00<?, ?batch/s]

  0%|          | 0/386 [00:00<?, ?batch/s]

  0%|          | 0/386 [00:00<?, ?batch/s]

  0%|          | 0/386 [00:00<?, ?batch/s]

  0%|          | 0/386 [00:00<?, ?batch/s]

  0%|          | 0/386 [00:00<?, ?batch/s]

  0%|          | 0/386 [00:00<?, ?batch/s]

  0%|          | 0/386 [00:00<?, ?batch/s]

  0%|          | 0/386 [00:00<?, ?batch/s]

  0%|          | 0/386 [00:00<?, ?batch/s]

  0%|          | 0/386 [00:00<?, ?batch/s]

  0%|          | 0/386 [00:00<?, ?batch/s]

  0%|          | 0/386 [00:00<?, ?batch/s]

  0%|          | 0/386 [00:00<?, ?batch/s]

  0%|          | 0/386 [00:00<?, ?batch/s]

  0%|          | 0/386 [00:00<?, ?batch/s]

  0%|          | 0/386 [00:00<?, ?batch/s]

  0%|          | 0/386 [00:00<?, ?batch/s]

  0%|          | 0/386 [00:00<?, ?batch/s]

  0%|          | 0/386 [00:00<?, ?batch/s]

  0%|          | 0/386 [00:00<?, ?batch/s]

  0%|          | 0/386 [00:00<?, ?batch/s]

  0%|          | 0/386 [00:00<?, ?batch/s]

  0%|          | 0/386 [00:00<?, ?batch/s]

  0%|          | 0/386 [00:00<?, ?batch/s]

  0%|          | 0/386 [00:00<?, ?batch/s]

  0%|          | 0/386 [00:00<?, ?batch/s]