This notebook is meant to be used together with [domain_adaptation_1](./domain_adaptation_1.ipynb).

## Feature extractor model

In [2]:
# --- Hyperparameters ---

# Segment geometry: every segment is `stride` samples long (which is also the
# spacing between segments) and carries `padding_length` extra samples on each
# side for additional context.
stride = 10
padding_length = 10
input_length = 2 * padding_length + stride  # segment plus padding on both sides

# Convolution / pooling settings.
out_channels = 32  # number of convolution filters (like 'filters' in keras)
kernel_size = 3
pool_size = 2

# Width of the discriminator's hidden fully connected layers.
discrim_width = 60

In [None]:
class MMSFeatExtract(nn.Module):  # TODO: leaky relu? strided convolution instead of pooling?
    """1D CNN feature extractor over seven separate input signals.

    Each of the seven inputs (bx, by, bz, ex, ey, ez, jy) is passed through its
    own Conv1d -> LeakyReLU -> AvgPool1d branch; the branches have the same
    architecture but independent weights. The seven branch outputs are averaged
    element-wise and fed to a second Conv1d -> LeakyReLU -> AvgPool1d stage
    that doubles the number of channels.

    Args:
        n_filters: Output channels of each per-signal convolution. Defaults to
            the module-level ``out_channels``.
        conv_kernel: Kernel size of every convolution. Defaults to the
            module-level ``kernel_size``.
        pool: Window of every average-pooling layer. Defaults to the
            module-level ``pool_size``.
    """

    def __init__(self, n_filters=None, conv_kernel=None, pool=None):
        super().__init__()
        # Fall back to the notebook-wide hyperparameters when nothing is passed,
        # so ``MMSFeatExtract()`` builds the same architecture as before.
        n_filters = out_channels if n_filters is None else n_filters
        conv_kernel = kernel_size if conv_kernel is None else conv_kernel
        pool = pool_size if pool is None else pool

        # Define these all separately because they will get different weights.
        # Consider merging them into one convolution with several input
        # channels; unclear whether that is a good idea.
        # (Creation order is kept as bx..jy so random initialisation is
        # unchanged for a fixed seed.)
        self.bx_layers = self._conv_stage(1, n_filters, conv_kernel, pool)
        self.by_layers = self._conv_stage(1, n_filters, conv_kernel, pool)
        self.bz_layers = self._conv_stage(1, n_filters, conv_kernel, pool)
        self.ex_layers = self._conv_stage(1, n_filters, conv_kernel, pool)
        self.ey_layers = self._conv_stage(1, n_filters, conv_kernel, pool)
        self.ez_layers = self._conv_stage(1, n_filters, conv_kernel, pool)
        self.jy_layers = self._conv_stage(1, n_filters, conv_kernel, pool)

        # TODO split this into CNN and classifier parts to facilitate domain adaptation
        self.post_merge_layers = self._conv_stage(n_filters, n_filters * 2, conv_kernel, pool)

    @staticmethod
    def _conv_stage(in_ch, out_ch, kernel, pool):
        """Build one Conv1d ('valid' padding) -> LeakyReLU -> AvgPool1d stage."""
        return nn.Sequential(nn.Conv1d(in_ch, out_ch, kernel, padding='valid'),
                             nn.LeakyReLU(),
                             nn.AvgPool1d(pool))

    def forward(self, bx, by, bz, ex, ey, ez, jy):
        """Extract features from the seven input signals.

        Every argument is fed to a ``Conv1d`` with one input channel, so each
        must be single-channel (e.g. shape ``(batch, 1, length)``) and all
        seven must produce branch outputs of the same shape.

        Returns:
            The output of the post-merge stage, with ``2 * n_filters`` channels.
        """
        branch_outputs = [
            self.bx_layers(bx),
            self.by_layers(by),
            self.bz_layers(bz),
            self.ex_layers(ex),
            self.ey_layers(ey),
            self.ez_layers(ez),
            self.jy_layers(jy),
        ]
        # Average over however many branches there are. The previous code
        # divided the sum of seven branches by a hard-coded 6.
        # NOTE(review): if a companion (source) extractor still divides by 6,
        # update it as well so both produce features on the same scale.
        combined = sum(branch_outputs) / len(branch_outputs)
        return self.post_merge_layers(combined)

## Discriminator part

In [None]:
class Discriminator(nn.Module):
    """Domain discriminator, based on the Tzeng et al. model.

    Flattens the incoming features per sample, passes them through two fully
    connected LeakyReLU layers and outputs a single raw logit per sample.

    Args:
        width: Number of units in each hidden layer. Defaults to the
            module-level ``discrim_width``.
    """

    def __init__(self, width=None):
        super().__init__()
        # Fall back to the notebook-wide hyperparameter so ``Discriminator()``
        # builds the same hidden layers as before.
        width = discrim_width if width is None else width
        # LazyLinear infers its input size from the first batch it sees.
        self.layer1 = nn.Sequential(nn.LazyLinear(width), nn.LeakyReLU())
        self.layer2 = nn.Sequential(nn.Linear(width, width), nn.LeakyReLU())
        # No activation on the output head: a LeakyReLU here would scale every
        # negative logit (and its gradient) by the negative slope, so the value
        # returned would not be a true logit. The Linear layer stays wrapped in
        # a Sequential so its parameter names ("domain_label.0.*") are unchanged.
        self.domain_label = nn.Sequential(nn.Linear(width, 1))

    def forward(self, features):
        """Return one domain logit per sample as a 1-D tensor.

        ``features`` is flattened to ``(features.shape[0], -1)`` first, so any
        trailing dimensions are accepted.
        """
        batch_size = features.shape[0]
        # reshape (rather than view) also accepts non-contiguous inputs
        flat = features.reshape(batch_size, -1)
        hidden = self.layer2(self.layer1(flat))
        domain_pred_logit = self.domain_label(hidden).view(-1)
        return domain_pred_logit

## Training loop

In [None]:
def train_adda(dl_source, dl_target, feat_extract_source, feat_extract_target, discriminator,
               loss_fn, optimizer_disc, optimizer_feat, iter_source=None):
    """Run one adversarial training pass over the target data loader.

    For every batch the discriminator is trained to tell source features
    (label 0) from target features (label 1); the target feature extractor is
    then trained against the discriminator using the reversed label (0).

    Args:
        dl_source: Source data loader. Each batch must unpack into 12 entries,
            of which entries 2-8 are the seven input signals.
        dl_target: Target data loader. Each batch must unpack into 10 entries:
            six signals, one ignored entry, jy, then two ignored entries.
        feat_extract_source: Source feature extractor. Its output is a dict
            whose "features" entry is used; it is put in eval mode and never
            updated here.
        feat_extract_target: Target feature extractor being adapted.
        discriminator: Domain discriminator returning one prediction per sample.
        loss_fn: Callable ``loss_fn(prediction, label)`` returning a scalar loss.
        optimizer_disc: Optimizer stepping the discriminator.
        optimizer_feat: Optimizer stepping the target feature extractor.
        iter_source: Existing iterator over ``dl_source`` to continue from; a
            new one is created when None.

    Returns:
        dict with keys "iter_source" (the source iterator to pass to the next
        call), "loss_disc" and "loss_feat" (lists with one float per batch).
    """
    feat_extract_source.eval()
    # Switch to training mode before the first forward pass rather than after it.
    feat_extract_target.train()
    discriminator.train()

    batch_size = dl_source.batch_size

    # Iterate enough that the target dataset is completely used.
    samples_source = len(dl_source.dataset)
    samples_target = len(dl_target.dataset)
    total_batches = samples_target // batch_size
    print(f"training on {samples_source} source samples and {samples_target} target samples")

    if iter_source is None:  # make the source iterator if it doesn't already exist
        print("Making source iterator")  # shouldn't ever fire with current setup / needs from the code
        iter_source = iter(dl_source)
    iter_target = iter(dl_target)

    # Accumulate across ALL batches. These lists used to be re-created inside
    # the loop, which discarded every loss but the last and left them undefined
    # when there were no batches.
    loss_disc = []
    loss_feat = []

    for batch in range(total_batches):
        # iter_or_restart_dl returns the next batch plus the iterator to keep
        # using (presumably a fresh one once the loader is exhausted).
        ds_source, iter_source = iter_or_restart_dl(dl_source, iter_source)
        ds_target, iter_target = iter_or_restart_dl(dl_target, iter_target)

        # unpack values
        _, _, bx_s, by_s, bz_s, ex_s, ey_s, ez_s, jy_s, _, _, _ = ds_source
        bx_t, by_t, bz_t, ex_t, ey_t, ez_t, _, jy_t, _, _ = ds_target

        # Source features are fixed inputs: no graph is needed for them.
        with torch.no_grad():
            feat_source = feat_extract_source(bx_s, by_s, bz_s, ex_s, ey_s, ez_s, jy_s)["features"]
        feat_target = feat_extract_target(bx_t, by_t, bz_t, ex_t, ey_t, ez_t, jy_t)

        # --- train the discriminator ---
        # Labels: target = 1, source = 0. They are built from the predictions
        # so they always match the actual batch size, dtype and device.
        optimizer_disc.zero_grad()
        pred_source = discriminator(feat_source)
        lossD_source = loss_fn(pred_source, torch.zeros_like(pred_source))
        lossD_source.backward()  # do not step yet; gradients must accumulate
        pred_target = discriminator(feat_target.detach())  # no gradient into the extractor
        lossD_target = loss_fn(pred_target, torch.ones_like(pred_target))
        lossD_target.backward()  # accumulate gradients
        optimizer_disc.step()

        # --- train the target feature extractor ---
        optimizer_feat.zero_grad()
        pred_target = discriminator(feat_target)
        lossF = loss_fn(pred_target, torch.zeros_like(pred_target))  # loss with reversed labels
        lossF.backward()
        optimizer_feat.step()

        lossD_value = (lossD_source.detach() + lossD_target.detach()).item()
        lossF_value = lossF.item()
        loss_disc.append(lossD_value)
        loss_feat.append(lossF_value)

        if (batch + 1) % 50 == 0:
            current_sample = (batch + 1) * batch_size
            # Report plain floats instead of tensor reprs.
            print(f"discrim loss: {lossD_value}, feat extract loss: {lossF_value}, sample {current_sample}/{total_batches*batch_size}")

    # Everything the caller needs between calls: the source iterator to resume
    # from and the per-batch loss histories.
    inter_iteration_stuff = {
        "iter_source": iter_source,
        "loss_disc": loss_disc,
        "loss_feat": loss_feat,
    }
    return inter_iteration_stuff

In [None]:
# Initialize the MMS feature extractor with the one trained on the sim data; that should be an OK starting point.
# Open question: does the optimizer for the target feature extractor need both its parameters and the discriminator's for correct backpropagation?

In [None]:
# Build both networks in double precision on the chosen device (discriminator
# first, then the MMS feature extractor), apply weights_init to every submodule
# of the feature extractor, and print both architectures.
discrim = Discriminator().to(dtype=torch.float64, device=device)
feat_mms = MMSFeatExtract().to(dtype=torch.float64, device=device).apply(weights_init)
print(discrim, feat_mms, sep="\n")