Model definitions (MMS feature extractor, domain discriminator) and the ADDA training loop, used with [domain_adaptation_1](./domain_adaptation_1.ipynb).

In [66]:
import torch
from torch import nn
import copy
import numpy as np
import torchinfo

## Feature extractor model

In [67]:
# Fixed hyperparameters describing how the time series is segmented
stride = 11  # length of each segment, which is also the spacing between segment starts
padding_length = 39  # extra context samples taken on each side of a segment
input_length = padding_length * 2 + stride  # samples per model input: left pad + segment + right pad

In [68]:
def repeat_layers_n_times(layer_list, n):
    """Return ``n`` back-to-back copies of ``layer_list`` as one flat list.

    Every element is an independent ``copy.deepcopy`` of its template layer,
    so the repeated layers are distinct objects and never share parameters
    (reusing the same module object several times in a Sequential would make
    those positions share weights).

    Args:
        layer_list: template sequence of layers, repeated in order.
        n: number of repetitions; ``n <= 0`` yields an empty list.

    Returns:
        A new list of length ``n * len(layer_list)`` laid out as
        ``[a, b, a', b', ...]``.
    """
    return [copy.deepcopy(layer) for _ in range(n) for layer in layer_list]

In [69]:
class MMSFeatExtract(nn.Module):  # TODO: strided convolution instead of pooling?
    """1D CNN feature extractor for the seven MMS input series.

    Each series (bx, by, bz, ex, ey, ez, jy) is processed by its own branch of
    ``mms_num_conv`` repetitions of
    ``LazyConv1d -> LeakyReLU -> AvgPool1d -> Dropout``, so every series gets
    separate weights. The branch outputs are averaged element-wise, passed
    through one more ``Conv1d -> LeakyReLU -> AvgPool1d`` stage that doubles
    the channel count, flattened, and linearly mapped to ``feat_shape[1:]``.

    Args:
        mms_num_conv: conv/activation/pool/dropout repetitions per branch.
        mms_kp_limit: not used by this class.
        mms_kernel_size: kernel size of every convolution (``padding='valid'``).
        mms_pool_size: AvgPool1d kernel size.
        mms_out_channels: output channels of every branch convolution.
        mms_learning_rate: not used by this class.
        mms_dropout_fraction: dropout probability inside each branch.
        feat_shape: desired feature shape *including* a leading batch
            dimension; only ``feat_shape[1:]`` is used.

    The branch convolutions and the resize Linear are lazy
    (``LazyConv1d`` / ``LazyLinear``), so their parameters only exist after a
    first forward pass (a dry run).
    """

    def __init__(self, mms_num_conv, mms_kp_limit, mms_kernel_size, mms_pool_size, mms_out_channels,
                 mms_learning_rate, mms_dropout_fraction, feat_shape):
        super().__init__()
        branch_args = (mms_num_conv, mms_out_channels, mms_kernel_size, mms_pool_size, mms_dropout_fraction)
        # One independent branch per series so each gets its own weights.
        # Consider merging them into a single convolution with in_channels=7; unclear if that is a good idea.
        self.bx_layers = self._make_branch(*branch_args)
        self.by_layers = self._make_branch(*branch_args)
        self.bz_layers = self._make_branch(*branch_args)
        self.ex_layers = self._make_branch(*branch_args)
        self.ey_layers = self._make_branch(*branch_args)
        self.ez_layers = self._make_branch(*branch_args)
        self.jy_layers = self._make_branch(*branch_args)

        self.post_merge_layers = nn.Sequential(
            nn.Conv1d(mms_out_channels, mms_out_channels * 2, mms_kernel_size, padding='valid'),
            nn.LeakyReLU(),
            nn.AvgPool1d(mms_pool_size),
        )
        feat_shape_nobatch = tuple(feat_shape[1:])
        # int() so the Linear gets a plain Python int rather than a NumPy scalar.
        self.resize_features = nn.Sequential(
            nn.Flatten(),
            nn.LazyLinear(int(np.prod(feat_shape_nobatch))),
            nn.Unflatten(-1, feat_shape_nobatch),
        )

    @staticmethod
    def _make_branch(num_conv, out_channels, kernel_size, pool_size, dropout_fraction):
        """Build one per-series branch; every repetition gets freshly constructed (unshared) layers."""
        layers = []
        for _ in range(num_conv):
            layers += [
                nn.LazyConv1d(out_channels, kernel_size, padding='valid'),
                nn.LeakyReLU(),
                nn.AvgPool1d(pool_size),
                nn.Dropout(p=dropout_fraction),
            ]
        return nn.Sequential(*layers)

    def forward(self, bx, by, bz, ex, ey, ez, jy):
        """Extract features from the seven series.

        Each argument is a ``(batch, channels, length)`` tensor; ``channels`` is
        inferred by the lazy first convolution of its branch. All branch
        outputs must have matching shapes so they can be averaged.

        Returns:
            Tensor of shape ``(batch, *feat_shape[1:])``.
        """
        branch_outputs = [
            self.bx_layers(bx), self.by_layers(by), self.bz_layers(bz),
            self.ex_layers(ex), self.ey_layers(ey), self.ez_layers(ez),
            self.jy_layers(jy),
        ]
        # Element-wise mean over all branches (divide by the actual branch count).
        combined = sum(branch_outputs) / len(branch_outputs)
        mms_features = self.post_merge_layers(combined)
        return self.resize_features(mms_features)

## Discriminator part

In [70]:
class Discriminator(nn.Module):
    """Domain discriminator, based on the Tzeng et al. (ADDA) model.

    ``discrim_length`` repetitions of ``LazyLinear(discrim_width) -> LeakyReLU``
    followed by a Linear layer that outputs one domain logit per sample.

    Args:
        discrim_width: hidden width of every layer.
        discrim_length: number of hidden Linear/LeakyReLU pairs.
    """

    def __init__(self, discrim_width, discrim_length):
        super().__init__()
        hidden = []
        for _ in range(discrim_length):
            # Fresh objects each repetition so no weights are shared.
            hidden += [nn.LazyLinear(discrim_width), nn.LeakyReLU()]
        self.layers = nn.Sequential(*hidden)
        self.domain_label = nn.Sequential(nn.Linear(discrim_width, 1))

    def forward(self, features):
        """Return a 1-D tensor of domain logits, one per row of ``features``.

        All non-batch dimensions of ``features`` are flattened first.
        """
        batch_size = features.shape[0]
        # reshape (unlike view) also accepts non-contiguous inputs.
        layers_out = self.layers(features.reshape(batch_size, -1))
        domain_pred_logit = self.domain_label(layers_out).view(-1)
        return domain_pred_logit

## Training loop

In [71]:
def train_adda(dl_source, dl_target, feat_extract_source, feat_extract_target, discriminator,
               loss_fn_disc, loss_fn_feat, optimizer_disc, optimizer_feat, iter_source=None):
    """Run one ADDA epoch, alternating discriminator and target-encoder updates.

    The source feature extractor is kept frozen (eval mode, no gradients).
    For each batch the discriminator is trained to label source features 0
    and target features 1. Then the target feature extractor is trained with
    inverted labels (its features labelled 0), so it learns to fool the
    discriminator. The loop runs ``len(dl_target.dataset) // dl_target.batch_size``
    batches, so the target (MMS) dataset is used once per call.

    Args:
        dl_source, dl_target: DataLoaders. Source batches unpack to 12 items
            (indices 2-8 are bx, by, bz, ex, ey, ez, jy). Target batches
            unpack to 10 items (indices 0-5 are bx..ez, index 7 is jy).
        feat_extract_source: frozen source encoder. Called with the seven
            series; returns a dict whose ``"features"`` entry is used.
        feat_extract_target: target encoder being trained; returns the
            feature tensor directly.
        discriminator: maps features to one logit per sample.
        loss_fn_disc, loss_fn_feat: called as ``loss(logits, labels)`` with
            0/1 labels shaped like the logits (presumably BCEWithLogitsLoss;
            confirm with the caller).
        optimizer_disc, optimizer_feat: presumably optimizers over the
            discriminator and target-encoder parameters respectively.
        iter_source: source iterator from a previous call, so source data
            continues where it left off; created if None.

    Batches are fetched with ``iter_or_restart_dl`` (helper defined elsewhere,
    presumably restarting an exhausted loader).

    Returns:
        dict with ``"iter_source"`` (pass it to the next call), and
        ``"loss_disc"`` / ``"loss_feat"`` (per-batch loss floats for this epoch).
    """
    feat_extract_source.eval()
    # Set training mode before the first forward pass so dropout etc. are
    # active from batch 0; nothing below switches modes again.
    feat_extract_target.train()
    discriminator.train()

    batch_size = dl_target.batch_size
    samples_source = len(dl_source.dataset)
    samples_target = len(dl_target.dataset)
    total_batches = samples_target // batch_size  # use the whole target (MMS) dataset once
    print(f"training on {samples_source} source samples and {samples_target} target samples")

    if iter_source is None:  # make the source iterator if it doesn't already exist
        print("Making source iterator")  # shouldn't ever fire with current setup / needs from the code
        iter_source = iter(dl_source)
    iter_target = iter(dl_target)

    loss_disc = []  # per-batch losses, accumulated across the whole epoch
    loss_feat = []
    for batch in range(total_batches):
        ds_source, iter_source = iter_or_restart_dl(dl_source, iter_source)
        ds_target, iter_target = iter_or_restart_dl(dl_target, iter_target)
        _, _, bx_s, by_s, bz_s, ex_s, ey_s, ez_s, jy_s, _, _, _ = ds_source
        bx_t, by_t, bz_t, ex_t, ey_t, ez_t, _, jy_t, _, _ = ds_target

        # The source encoder is frozen, so no autograd graph is needed.
        with torch.no_grad():
            feat_source = feat_extract_source(bx_s, by_s, bz_s, ex_s, ey_s, ez_s, jy_s)["features"]
        feat_target = feat_extract_target(bx_t, by_t, bz_t, ex_t, ey_t, ez_t, jy_t)

        # --- Discriminator step: source -> 0, target -> 1. ---
        # Labels take the logits' shape/dtype/device, so a short final batch or
        # different source/target batch sizes still line up.
        optimizer_disc.zero_grad()
        pred_source = discriminator(feat_source)
        lossD_source = loss_fn_disc(pred_source, torch.zeros_like(pred_source))
        lossD_source.backward()  # don't step yet: accumulate the target term first
        pred_target = discriminator(feat_target.detach())  # keep the encoder out of this step
        lossD_target = loss_fn_disc(pred_target, torch.ones_like(pred_target))
        lossD_target.backward()
        optimizer_disc.step()
        lossD = lossD_source.item() + lossD_target.item()

        # --- Target encoder step with inverted labels (target -> 0, i.e. "source"). ---
        optimizer_feat.zero_grad()
        pred_target = discriminator(feat_target)
        lossF = loss_fn_feat(pred_target, torch.zeros_like(pred_target))
        lossF.backward()  # also fills discriminator grads; cleared by the next optimizer_disc.zero_grad()
        optimizer_feat.step()
        lossF = lossF.item()

        loss_disc.append(lossD)
        loss_feat.append(lossF)

        if (batch + 1) % 50 == 0:
            current_sample = (batch + 1) * batch_size
            print(f"discrim loss: {lossD}, feat extract loss: {lossF}, sample {current_sample}/{total_batches*batch_size}")

    return {"iter_source": iter_source, "loss_disc": loss_disc, "loss_feat": loss_feat}

In [72]:
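# Initialize the MMS feature extractor from the one trained on the simulation data; that should be an OK starting point.
# Open question: does the target feature-extractor optimizer need the parameters of both the extractor and the discriminator for correct backpropagation?
# (Gradients flow back through the discriminator either way; an optimizer only needs the parameters it should update.)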

In [73]:
def create_adda(mms_num_conv, mms_kp_limit, mms_kernel_size, mms_pool_size, mms_out_channels, mms_learning_rate, mms_dropout_fraction,
                discrim_width, discrim_length, feat_shape):
    """Build the ADDA discriminator and the MMS (target) feature extractor.

    Both models are moved to the module-level ``device`` in float64. Only the
    feature extractor has ``weights_init`` applied. ``weights_init`` is defined
    outside this notebook (presumably in utils.ipynb); it was described here as
    applying Gaussian weights.

    Returns:
        ``(discriminator, feature_extractor)``.
    """
    placement = dict(device=device, dtype=torch.double)
    discrim = Discriminator(discrim_width, discrim_length).to(**placement)
    mms_args = (mms_num_conv, mms_kp_limit, mms_kernel_size, mms_pool_size, mms_out_channels,
                mms_learning_rate, mms_dropout_fraction, feat_shape)
    feat_mms = MMSFeatExtract(*mms_args).to(**placement)
    feat_mms.apply(weights_init)
    return discrim, feat_mms

### If run alone: save a model with random weights

In [None]:
if __name__ == '__main__' and '__file__' not in globals():  # do not run if %run from another notebook
    device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
    print(f"Using {device} device")
    dtype = torch.double
    %run /tigress/kendrab/analysis-notebooks/torch_models/utils.ipynb
    # some sim model hyperparameters
    padding_length = 39  # amount of data on each side of each segment for additional info
    stride = 11  # size (and therefore spacing) of each segment
    input_length = stride + 2*padding_length
    batch_size = 11
    mms_kp_limit = 9  # not neeeded when loading from file
    mms_num_conv = 1
    mms_kernel_size = 3
    mms_pool_size = 4
    mms_out_channels = 40
    mms_learning_rate = 0.0012094769607738786
    mms_dropout_fraction = 0.03457536835651724
    feat_shape = [batch_size, 72, 3]
    mock_data = torch.ones((batch_size, 1, input_length), dtype=dtype)
    feat_mms = MMSFeatExtract(mms_num_conv, mms_kp_limit, mms_kernel_size, mms_pool_size, mms_out_channels, mms_learning_rate,
                              mms_dropout_fraction, feat_shape).to(device=device, dtype=torch.double)
    # dry run to initialize lazy modules
    feat_mms(*[mock_data for i in range(7)])
    print(torchinfo.summary(feat_mms))
    feat_mms.apply(lambda m: weights_init(m, mean=-0.16, stdev=1.46))  # apply gaussian weights  
    torch.save(feat_mms.state_dict(), "/tigress/kendrab/analysis-notebooks/model_outs/"+"mms_random_featextract_statedict.tar")