In [1]:
import torch
from torch import nn
from torch.nn import functional as F

In [43]:
class Model1(nn.Module):
    def __init__(self,in_f,mid_f):
        super().__init__()
        self.listm= nn.ModuleList([
            nn.Linear(in_f*i,in_f*(i+1)) for i in range(1,5)
        ])
        self.head=nn.Linear(in_f*5,mid_f)

    def forward(self,x: torch.Tensor):
        for layer in self.listm:
            x=layer(x)
        return self.head(x)

In [44]:
class Model2(nn.Module):
    def __init__(self,mid_f,out_f):
        super().__init__()
        self.listm= nn.ModuleList([
            nn.Linear(mid_f*i,mid_f*(i+1)) for i in range(1,5)
        ])
        self.head=nn.Linear(mid_f*5,out_f)

    def forward(self,x: torch.Tensor):
        for layer in self.listm:
            x=layer(x)
        return self.head(x)

In [45]:
IN_F=8
MID_F=128
OUT_F=8
BATCH_SIZE = 32

STEPS = 50


In [71]:
torch.manual_seed(42)
cuda = False
if torch.cuda.is_available() and torch.cuda.device_count >= 2:
    cuda = True
# Model
part1 = Model1(IN_F, MID_F)
part2 = Model2(MID_F, OUT_F)

#Data
fixed_input =torch.rand(BATCH_SIZE,IN_F)
fixed_output =torch.rand(BATCH_SIZE,OUT_F)

if cuda:
    part1=part1.to("cuda:0")
    part2=part2.to("cuda:1")
    fixed_input=fixed_input.to("cuda:0")
    fixed_output=fixed_output.to("cuda:1")

optimizer = torch.optim.Adam(list(part1.parameters()) + list(part2.parameters()), lr=3e-4)

In [72]:
loss_fn = nn.CrossEntropyLoss()
import time

In [73]:

def trainer(model1,model2,optim,step,fixed_input,fixed_output,loss_fn,is_cuda=False):
    model1.train()
    model2.train()
    st=time.time()
    for i in range(step):
        hidden=model1(fixed_input)
        if is_cuda:
            hiden=hidden.to("cuda:1")
        hidden.retain_grad()
        logits=model2(hidden)
        loss=loss_fn(logits, fixed_output)
        loss.backward()
        optim.step()
        optim.zero_grad()
        if step%5==0:
            print(f"Final Loss: {loss.item():.6f} ")
        
    duration=time.time()-st
    print(f"Final Loss: {loss.item():.6f} | Time: {duration:.3f}s")
        

In [74]:
trainer(part1,part2,optimizer,STEPS,fixed_input,fixed_output,loss_fn,is_cuda=cuda)

Final Loss: 8.090343 
Final Loss: 8.071132 
Final Loss: 8.060573 
Final Loss: 8.057980 
Final Loss: 8.060626 
Final Loss: 8.061011 
Final Loss: 8.058400 
Final Loss: 8.055301 
Final Loss: 8.053020 
Final Loss: 8.051542 
Final Loss: 8.050324 
Final Loss: 8.048800 
Final Loss: 8.046567 
Final Loss: 8.043437 
Final Loss: 8.039425 
Final Loss: 8.034662 
Final Loss: 8.029222 
Final Loss: 8.023099 
Final Loss: 8.016482 
Final Loss: 8.009660 
Final Loss: 8.003029 
Final Loss: 7.997621 
Final Loss: 7.994514 
Final Loss: 7.993891 
Final Loss: 7.993565 
Final Loss: 7.990621 
Final Loss: 7.986303 
Final Loss: 7.982062 
Final Loss: 7.978249 
Final Loss: 7.976545 
Final Loss: 7.976016 
Final Loss: 7.976521 
Final Loss: 7.976678 
Final Loss: 7.975892 
Final Loss: 7.974224 
Final Loss: 7.972164 
Final Loss: 7.970149 
Final Loss: 7.967823 
Final Loss: 7.965805 
Final Loss: 7.963854 
Final Loss: 7.961952 
Final Loss: 7.960265 
Final Loss: 7.958597 
Final Loss: 7.957114 
Final Loss: 7.955358 
Final Loss