In [71]:
import os
import sys
import numpy as np
import torch

from datetime import datetime
from typing import Tuple

from torch.nn import Module
import torch.nn.functional as F
from torch.nn import KLDivLoss, CrossEntropyLoss, CosineEmbeddingLoss
from torch.optim import Adam
from torch.optim.lr_scheduler import ReduceLROnPlateau

### Loading Teacher model ---> CLIP Image Extractor 

In [None]:
import clip

model_name = "ViT-B/32"

# model is the torch model.
# preprocess function is for image preprocessing.

model, preprocess = clip.load(model_name)

# Get only the visual model
teacher_model = model.visual
input_resolution = model.visual.input_resolution

print(
    "Model parameters:",
    f"{np.sum([int(np.prod(p.shape)) for p in model.visual.parameters()]):,}",
)
print("Input resolution:", input_resolution)

### Instantiating Student model 

[VisionTransformer](https://github.com/openai/CLIP/blob/d50d76daa670286dd6cacf3bcd80b5e4823fc8e1/clip/model.py#L206)

In [None]:
from clip.model import VisionTransformer
from clip.model import convert_weights # Make them float16
# Set Student Configuration

patch_size = 32
width = 384
layers = 6
heads = 12
output_dim = 512

student_model = VisionTransformer(
    input_resolution=input_resolution,
    patch_size=patch_size,
    width=width,
    layers=layers,
    heads=heads,
    output_dim=output_dim,
)



convert_weights(student_model)


print(
    "Model parameters:",
    f"{np.sum([int(np.prod(p.shape)) for p in student_model.parameters()]):,}",
)


### Load the WIT Dataset

In [None]:
from concurrent.futures import ThreadPoolExecutor
from functools import partial
import io
import urllib

import PIL.Image

from datasets import load_dataset
from datasets.utils.file_utils import get_datasets_user_agent


def fetch_single_image(image_url, timeout=None, retries=0):
    for _ in range(retries + 1):
        try:
            request = urllib.request.Request(
                image_url,
                data=None,
                headers={"user-agent": get_datasets_user_agent()},
            )
            with urllib.request.urlopen(request, timeout=timeout) as req:
                image = PIL.Image.open(io.BytesIO(req.read()))
            break
        except Exception:
            image = None
    return image


def fetch_images(batch, num_threads, timeout=None, retries=0):
    fetch_single_image_with_args = partial(
        fetch_single_image, timeout=timeout, retries=retries
    )
    with ThreadPoolExecutor(max_workers=num_threads) as executor:
        batch["image"] = list(
            executor.map(fetch_single_image_with_args, batch["image_url"])
        )
    return batch


num_threads = 20
dset = load_dataset("cifar10")
# dset = dset.map(
#     fetch_images, batched=True, batch_size=100, fn_kwargs={"num_threads": num_threads}
# )

In [None]:
import torchvision
import torchvision.transforms as transforms

transform = transforms.Compose([
                transforms.Resize(224),
                transforms.RandomCrop(224),
                #transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

cifar100 = torchvision.datasets.CIFAR100('data/',download=True,train=True,transform=transform)
train_dataloader = torch.utils.data.DataLoader(cifar100,
                                          batch_size=4,
                                          shuffle=True,
                                          num_workers=8)

In [None]:
class DistillationTrainer:
    def __init__(self, *args, **kwargs):
        self.teacher = teacher_model
        self.student = student_model
        self.preprocess = preprocess
        self.train_dataloader = train_dataloader

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.teacher = self.teacher.to(self.device)
        self.student = self.student.to(self.device)
        self.teacher.eval()

        self.epochs = 30
        self.start_epoch = 1

        # set up optimizer
        self.optimizer = Adam(self.student.parameters(),lr=0.0001)

        # Set up LR Scheduler
        self.lr_scheduler = ReduceLROnPlateau(self.optimizer, "min")

    def compute_loss(self, images, return_outputs=False):
        images = images.to(self.device).half()
        outputs_student = self.student(images)

        # compute teacher output
        with torch.no_grad():
            outputs_teacher = self.teacher(images)

        # assert size
        assert outputs_student.size() == outputs_teacher.size()

        # Soften probabilities and compute distillation loss
        
        print("Before Loss ")
        print("outputs_student",outputs_student)
        print("outputs_teacher",outputs_teacher)
        # KL Divergence Loss
        loss = 0 
#         kl_loss = KLDivLoss(reduction="batchmean",log_target =True)
#         loss = kl_loss(F.log_softmax(outputs_student),F.log_softmax(outputs_teacher))
        # Cosine loss
        loss = loss + CosineEmbeddingLoss()(
            outputs_teacher, outputs_student, torch.ones(outputs_teacher.size()[0]).to(self.device)
        )

        return  loss

    def train(self):
        for epoch in range(self.start_epoch, self.epochs + 1):
            loss_value = self._train_epoch(epoch)
            print(f"KLD-CosineLoss after {epoch} Epoch is {loss_value}")

    def _train_epoch(self, epoch):
        loss_value = 0
        for batch_idx, (images, _) in enumerate(self.train_dataloader):
            
            loss = self.compute_loss( images)
            print(loss)
            loss_value += loss

            self.optimizer.zero_grad()
            
            loss.backward()
            
            self.optimizer.step()
            
            torch.autograd.set_detect_anomaly(True)
            
            if batch_idx % 10 == 0:
                print(f"Loss after {batch_idx} Batch is {loss_value/(batch_idx+1)} ")

        return loss_value.detach().cpu().numpy() / len(self.train_dataloader)

In [None]:
Trainer = DistillationTrainer(
    teacher_model=teacher_model,
    student_model=student_model,
    train_dataloader=train_dataloader,
    preprocess = preprocess,
)

In [88]:
Trainer.train()

Before Loss 
outputs_student tensor([[-0.7388, -0.7339, -0.2544,  ...,  0.3210,  0.2468,  0.6230],
        [ 0.0182, -0.6816, -0.0752,  ..., -0.9097, -1.2627,  1.2275],
        [-0.7944, -0.9746,  0.2808,  ..., -0.1247, -0.6689,  1.4512],
        [-0.8433, -1.1465,  0.1843,  ..., -0.5142, -0.7632,  1.4170]],
       device='cuda:0', dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.2181,  0.0105, -0.4763,  ...,  1.1885,  0.0987,  0.0034],
        [-0.1395,  0.1727, -0.2703,  ...,  0.8687, -0.2437, -0.0152],
        [ 0.2505,  0.4880, -0.2688,  ...,  0.7422, -0.0636, -0.0924],
        [-0.0498, -0.2864, -0.3735,  ...,  0.6250,  0.0225,  0.3813]],
       device='cuda:0', dtype=torch.float16)
tensor(1., device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Loss after 0 Batch is 1.0 
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan,

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.0283, -0.0120, -0.0847,  ...,  0.9233,  0.2600,  0.0150],
        [ 0.2087,  0.1022, -0.4045,  ...,  0.6216,  0.0068,  0.2869],
        [ 0.2484,  0.3079, -0.1107,  ...,  0.8770,  0.1857,  0.0529],
        [ 0.0265,  0.0417, -0.1326,  ...,  0.5859, -0.0679, -0.2708]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.2573, -0.2397, -0.1908,  ...,  0.8755, -0.3230, -0.0560],
        [ 0.1700,  0.0986, -0.0720,  ...,  0.3979, -0.2922, -0.1670],
        [-0.0754,  0.2959, -0.2725,  ...,  0.9336,  0.0854,  0.0665],
        [ 0.2778,  0.4514, -0.1006,  ...,  0.7769,  0.1198,  0.2074]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.2537,  0.1090, -0.1887,  ...,  1.0518,  0.1171,  0.1809],
        [ 0.3992,  0.3145, -0.5615,  ...,  0.9976,  0.0469, -0.4189],
        [-0.2136, -0.2725, -0.2188,  ...,  0.6831, -0.2729, -0.1305],
        [ 0.3408, -0.3760, -0.0490,  ...,  0.5918,  0.3379,  0.1055]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Loss after 40 Batch is nan 
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBack

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.3262,  0.2153, -0.2026,  ...,  0.6260, -0.2301, -0.1753],
        [ 0.1909, -0.0145, -0.4438,  ...,  0.6968, -0.2983,  0.3940],
        [ 0.0503, -0.2079, -0.1667,  ...,  0.8462, -0.0712, -0.3184],
        [ 0.0659, -0.0537, -0.4695,  ...,  0.7861, -0.2411,  0.0692]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.4221, -0.1310, -0.1313,  ...,  0.7168,  0.4021, -0.2155],
        [ 0.4417, -0.0359, -0.5215,  ...,  0.7373,  0.0177,  0.0183],
        [ 0.0659,  0.3062, -0.1716,  ...,  0.3779, -0.3167,  0.2356],
        [ 0.2915, -0.0237,  0.2164,  ...,  0.4009, -0.0215,  0.0459]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[-0.0018, -0.0406, -0.1089,  ...,  0.9023,  0.0686, -0.1163],
        [ 0.5811,  0.1555, -0.2935,  ...,  0.2761, -0.1069, -0.0193],
        [ 0.1388, -0.0073, -0.4329,  ...,  0.5063, -0.2106,  0.2981],
        [-0.1649,  0.1890, -0.0089,  ...,  0.6846, -0.1875,  0.0630]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.2942,  0.2988, -0.2281,  ...,  0.7021,  0.0907,  0.0073],
        [ 0.1077, -0.0622, -0.3552,  ...,  0.8599, -0.0929,  0.6753],
        [ 0.0755,  0.2827, -0.4155,  ...,  0.5850, -0.4160,  0.1060],
        [ 0.2468,  0.1686, -0.3184,  ...,  0.9062,  0.1344,  0.2146]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.1222,  0.2323, -0.4365,  ...,  0.8423, -0.2612,  0.1562],
        [ 0.2837, -0.2761, -0.3459,  ...,  1.1279,  0.1803,  0.1252],
        [ 0.0946,  0.0320, -0.1660,  ...,  0.8491, -0.3123,  0.0050],
        [ 0.2047, -0.1938, -0.4470,  ...,  0.3745,  0.4199,  0.0745]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Loss after 100 Batch is nan 
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBac

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.0843,  0.3347, -0.2329,  ...,  1.0215, -0.0173,  0.1510],
        [ 0.2075,  0.1782, -0.3813,  ...,  0.8384, -0.0502,  0.1803],
        [ 0.3362, -0.3584, -0.4229,  ...,  0.6201, -0.1327,  0.2661],
        [ 0.2494,  0.0798, -0.6045,  ...,  0.8892,  0.3042, -0.0064]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[-0.2174, -0.0626, -0.0763,  ...,  0.6982,  0.0047,  0.0687],
        [ 0.4675,  0.3411, -0.5552,  ...,  0.5112, -0.1395,  0.0980],
        [ 0.4355, -0.0020, -0.1453,  ...,  0.8247, -0.1774,  0.1559],
        [ 0.2360,  0.0862, -0.3831,  ...,  0.5308,  0.0409, -0.0129]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.1213,  0.1213, -0.3684,  ...,  0.7354,  0.2175,  0.0378],
        [-0.0175, -0.0043, -0.1377,  ...,  0.7905, -0.2705, -0.0643],
        [ 0.2161,  0.3022, -0.3228,  ...,  0.5293,  0.1306,  0.0537],
        [ 0.3003,  0.2303, -0.3301,  ...,  0.7583, -0.2181,  0.1886]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.0884,  0.1968, -0.2450,  ...,  0.4783,  0.0410, -0.0078],
        [ 0.0864, -0.2947, -0.3000,  ...,  1.1885, -0.0049,  0.3164],
        [ 0.2798,  0.4580, -0.3313,  ...,  0.9878, -0.1112, -0.3921],
        [-0.1815,  0.7412,  0.0594,  ...,  0.8359,  0.2125, -0.1224]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.0357,  0.0327, -0.4036,  ...,  0.5991, -0.3518,  0.4868],
        [ 0.1973, -0.0895, -0.1920,  ...,  0.6582, -0.0410,  0.4045],
        [ 0.4688,  0.1804, -0.3315,  ...,  0.9844, -0.0691,  0.2139],
        [ 0.6460,  0.4587, -0.3286,  ...,  0.7988,  0.0798, -0.1190]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Loss after 160 Batch is nan 
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBac

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.0743,  0.0562, -0.5503,  ...,  0.8384,  0.0341, -0.3420],
        [-0.2441,  0.0913, -0.2576,  ...,  1.0781, -0.2218,  0.0026],
        [ 0.4128,  0.1919, -0.3225,  ...,  0.7007, -0.3152,  0.1254],
        [ 0.1956, -0.0915, -0.2223,  ...,  0.5767, -0.0129,  0.0676]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.3499,  0.3447, -0.2695,  ...,  0.9731, -0.1398, -0.1249],
        [ 0.3657, -0.3572, -0.0832,  ...,  0.7095,  0.1293,  0.0179],
        [ 0.1091,  0.4094,  0.0244,  ...,  0.6353,  0.0663,  0.0873],
        [-0.0577, -0.0014, -0.3242,  ...,  0.8643, -0.3413,  0.2134]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.3950, -0.0859, -0.4282,  ...,  0.7783,  0.2842,  0.2029],
        [ 0.3735,  0.3171, -0.2595,  ...,  0.5425, -0.1249,  0.1687],
        [ 0.3149,  0.2615, -0.6763,  ...,  0.6064,  0.0900,  0.1866],
        [ 0.1381,  0.1337, -0.1836,  ...,  0.7217, -0.0807,  0.2480]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.5649,  0.0577, -0.4734,  ...,  0.4490, -0.3196,  0.2019],
        [ 0.2235, -0.1448, -0.3032,  ...,  0.5513,  0.3279, -0.2231],
        [-0.0406,  0.0936, -0.1919,  ...,  0.4138,  0.1624, -0.2944],
        [ 0.4465, -0.0071, -0.2866,  ...,  0.6523,  0.0776,  0.1134]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[-0.0953,  0.1616, -0.1224,  ...,  1.0303, -0.2054, -0.0904],
        [ 0.3323,  0.0462, -0.5581,  ...,  0.5830,  0.4346,  0.2727],
        [ 0.0288,  0.3518, -0.2834,  ...,  0.6919,  0.1913,  0.0408],
        [ 0.2205,  0.6016, -0.1608,  ...,  0.4111,  0.3235,  0.6714]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Loss after 220 Batch is nan 
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBac

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.4004,  0.0692, -0.3579,  ...,  0.6997,  0.1685,  0.4861],
        [ 0.4082,  0.0373, -0.5508,  ...,  0.7788,  0.1854,  0.0816],
        [ 0.5732,  0.1144, -0.1377,  ...,  0.4141, -0.1488,  0.3550],
        [ 0.1792,  0.0265, -0.4614,  ...,  0.7217, -0.3486,  0.1997]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.0665,  0.1785, -0.2297,  ...,  0.8394,  0.0335, -0.0644],
        [ 0.0792, -0.2164, -0.3020,  ...,  0.9111,  0.0600, -0.0988],
        [ 0.2251,  0.0720, -0.5361,  ...,  0.5078,  0.1796,  0.2079],
        [-0.0323,  0.1415, -0.3181,  ...,  0.9546,  0.0088, -0.2164]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.0848,  0.1837, -0.0776,  ...,  1.0674, -0.1975, -0.2458],
        [ 0.2646,  0.2192, -0.4856,  ...,  0.7822,  0.0592,  0.0188],
        [ 0.1611,  0.2186, -0.4768,  ...,  1.0068, -0.2047,  0.0129],
        [ 0.0682, -0.1050, -0.2194,  ...,  0.6055, -0.1121,  0.4690]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.4534,  0.2069, -0.4365,  ...,  0.9541, -0.1096,  0.1472],
        [ 0.1111,  0.2346, -0.4495,  ...,  1.1465, -0.0959,  0.5015],
        [ 0.0158,  0.3926, -0.0844,  ...,  0.3921, -0.1632, -0.1473],
        [ 0.1621,  0.1167, -0.3240,  ...,  0.5815, -0.1024,  0.1971]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.1110,  0.2384, -0.3806,  ...,  0.7593, -0.0865, -0.0942],
        [ 0.2544,  0.3115, -0.2791,  ...,  0.7554,  0.0620,  0.0670],
        [ 0.4124,  0.0766, -0.3940,  ...,  0.5083,  0.1853,  0.0496],
        [ 0.3198,  0.7090, -0.0996,  ...,  0.8037, -0.1349, -0.0156]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Loss after 280 Batch is nan 
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBac

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[-0.0039, -0.1748, -0.2815,  ...,  0.8003,  0.1131,  0.1088],
        [ 0.6719,  0.2389, -0.2133,  ...,  0.9131, -0.1439,  0.3513],
        [ 0.4553, -0.2642, -0.0801,  ...,  0.6699,  0.1589,  0.0866],
        [ 0.2101,  0.1005, -0.2406,  ...,  0.3950,  0.2126,  0.0941]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.1099, -0.0896, -0.2350,  ...,  0.5205,  0.4050,  0.2932],
        [ 0.1611, -0.0676, -0.3198,  ...,  1.0088, -0.1244,  0.2791],
        [ 0.2607,  0.1860,  0.0477,  ...,  1.0547, -0.0109,  0.0146],
        [ 0.2222, -0.0504, -0.1185,  ...,  0.8145,  0.0872,  0.2966]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[-0.2362,  0.2330, -0.1348,  ...,  1.0791,  0.1162,  0.0871],
        [ 0.1118,  0.3142, -0.3254,  ...,  0.6685, -0.1559, -0.0632],
        [ 0.2639,  0.3904, -0.1418,  ...,  0.4517,  0.1456,  0.3000],
        [-0.0164,  0.2712,  0.0329,  ...,  0.6401, -0.1688, -0.1396]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.2915,  0.2230, -0.3806,  ...,  0.7427, -0.3110, -0.3638],
        [ 0.2703,  0.4612, -0.3132,  ...,  0.1287,  0.1307,  0.3713],
        [ 0.1923,  0.6147, -0.1898,  ...,  0.7725,  0.3164,  0.2598],
        [ 0.0182,  0.2137, -0.1649,  ...,  0.8506, -0.1558,  0.2805]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.1940,  0.3052, -0.2856,  ...,  0.9517, -0.4058,  0.1786],
        [ 0.1824,  0.2072, -0.2722,  ...,  0.5718, -0.2177, -0.0728],
        [ 0.0350,  0.3000, -0.3950,  ...,  0.8389, -0.0477,  0.2986],
        [ 0.1002,  0.1667, -0.1539,  ...,  0.5811, -0.2656,  0.0605]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Loss after 340 Batch is nan 
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBac

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.5918,  0.4319, -0.4673,  ...,  0.4854, -0.3196,  0.6304],
        [ 0.3240,  0.3069, -0.1243,  ...,  0.1174,  0.3569,  0.1121],
        [-0.0050, -0.0040, -0.1183,  ...,  0.6670, -0.4399,  0.2520],
        [ 0.3098,  0.1220, -0.2343,  ...,  0.8442,  0.0850,  0.0356]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.1642,  0.1292, -0.0091,  ...,  0.8711,  0.2013,  0.2043],
        [-0.0026,  0.1231, -0.3096,  ...,  1.0410, -0.2291, -0.0896],
        [ 0.2703, -0.1068, -0.3142,  ...,  0.4937,  0.3110, -0.2001],
        [ 0.3125,  0.3447, -0.5088,  ..., -0.1121,  0.3674,  0.0393]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[-0.0137,  0.2306, -0.2272,  ...,  1.0205,  0.1464,  0.3948],
        [ 0.2803,  0.0350, -0.3518,  ...,  0.7734,  0.1660,  0.0842],
        [ 0.0523,  0.0123, -0.0828,  ...,  0.8242, -0.3948, -0.0768],
        [ 0.4563,  0.2546, -0.2759,  ...,  0.4194,  0.0107,  0.2500]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.2668,  0.2969, -0.4604,  ...,  0.8530, -0.0620,  0.2061],
        [ 0.3228, -0.1537, -0.4272,  ...,  0.4978, -0.0934,  0.4260],
        [ 0.1270,  0.3394, -0.3438,  ...,  0.8667, -0.1125,  0.1228],
        [ 0.1962, -0.0195, -0.4856,  ...,  0.4006,  0.1458,  0.2878]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.3638, -0.3582, -0.0299,  ...,  0.4155,  0.4893,  0.0420],
        [ 0.1400,  0.0203, -0.4873,  ...,  1.0928, -0.0107,  0.1608],
        [ 0.3818,  0.1394, -0.5122,  ...,  0.8560,  0.0434, -0.0627],
        [-0.0114, -0.1846, -0.2939,  ...,  0.8179, -0.1744, -0.0080]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Loss after 400 Batch is nan 
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBac

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.2766, -0.0231, -0.3420,  ...,  0.7397,  0.0091,  0.0700],
        [ 0.1057,  0.0932, -0.1937,  ...,  0.8896, -0.4983,  0.0247],
        [ 0.3772, -0.3162, -0.3193,  ...,  0.6934,  0.1843,  0.3586],
        [ 0.3521,  0.1545, -0.0533,  ...,  0.6270,  0.0039, -0.1801]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.1382,  0.0324, -0.4243,  ...,  0.6416, -0.1260,  0.3042],
        [ 0.2766,  0.0411, -0.3950,  ...,  0.8057, -0.0829,  0.2100],
        [ 0.2246, -0.1440, -0.3630,  ...,  0.9727, -0.2627,  0.2482],
        [ 0.2080,  0.2054, -0.3303,  ...,  0.7837, -0.1079,  0.1744]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Loss after 440 Batch is nan 
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.2190,  0.0917, -0.4634,  ...,  0.5698, -0.1357, -0.0370],
        [ 0.2634,  0.1168, -0.5898,  ...,  0.6265, -0.2219,  0.0805],
        [ 0.1849,  0.0882, -0.0884,  ...,  0.6782,  0.1061,  0.0961],
        [ 0.5625,  0.4153, -0.2917,  ...,  0.5005, -0.3589,  0.4351]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBac

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[-0.0009,  0.1393, -0.3201,  ...,  0.6606, -0.1754, -0.0481],
        [ 0.0745,  0.0186, -0.6094,  ...,  0.3989,  0.2009,  0.2764],
        [ 0.4614,  0.1141, -0.2047,  ...,  0.5151,  0.0705,  0.0319],
        [ 0.1248, -0.0986, -0.0528,  ...,  0.5044, -0.4172, -0.1536]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.0477,  0.1547, -0.3469,  ...,  0.2310,  0.0969,  0.1241],
        [ 0.3770,  0.5586,  0.1094,  ...,  0.6528,  0.0404,  0.2942],
        [ 0.3987,  0.5806, -0.0687,  ...,  0.3955,  0.0329,  0.1400],
        [ 0.1931,  0.3586, -0.3313,  ...,  0.7856,  0.1793,  0.1230]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.2467, -0.2996, -0.3020,  ...,  0.5186, -0.0984, -0.2854],
        [ 0.4749,  0.2832, -0.3979,  ...,  1.1113,  0.1486,  0.4668],
        [ 0.2141,  0.0140,  0.2045,  ...,  0.9922, -0.1162,  0.0333],
        [ 0.3657,  0.5786, -0.1685,  ...,  0.8032,  0.0418, -0.0180]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.3909, -0.0020, -0.0797,  ...,  0.3740,  0.1810, -0.0973],
        [ 0.3955,  0.2373, -0.2117,  ...,  0.6890, -0.0293,  0.2529],
        [ 0.3542,  0.1428, -0.8369,  ...,  0.6162, -0.1597,  0.2184],
        [ 0.2201,  0.1687, -0.3687,  ...,  0.4277, -0.0706,  0.0134]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.1160,  0.1550, -0.1224,  ...,  0.6367, -0.2629, -0.1270],
        [ 0.6899,  0.2576, -0.3921,  ...,  0.7896,  0.1182,  0.3335],
        [-0.0606, -0.0461, -0.0975,  ...,  0.7915, -0.3108, -0.1080],
        [ 0.4360,  0.2759, -0.2377,  ...,  0.8452, -0.1215,  0.0115]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 6.1462e-02, -3.3594e-01, -3.7231e-01,  ...,  1.1240e+00,
         -1.2445e-01,  9.7084e-04],
        [ 2.0947e-01, -3.4814e-01, -3.8605e-02,  ...,  6.4355e-01,
         -1.6089e-01, -2.1338e-01],
        [ 2.9492e-01, -2.8641e-02, -3.8721e-01,  ...,  8.3887e-01,
         -6.9763e-02,  2.8857e-01],
        [ 2.6147e-01,  1.1334e-01, -2.9712e-01,  ...,  3.3228e-01,
         -2.3328e-01,  3.4375e-01]], device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [n

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.3567,  0.0418, -0.2498,  ...,  0.3818,  0.1505, -0.0349],
        [ 0.3110, -0.1592, -0.2966,  ...,  0.5786, -0.1487,  0.3289],
        [ 0.1230,  0.1356, -0.3923,  ...,  1.0557,  0.0518,  0.2478],
        [ 0.4287,  0.0329, -0.2510,  ...,  0.6797,  0.2123,  0.1296]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.7651,  0.4519, -0.3376,  ...,  0.3140, -0.1620, -0.0934],
        [ 0.0219,  0.1270, -0.4014,  ...,  0.8770, -0.1516,  0.1001],
        [ 0.0753,  0.3611, -0.4556,  ...,  0.9326, -0.2103,  0.0116],
        [-0.0118,  0.1046, -0.1512,  ...,  0.9429, -0.0877, -0.0089]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.2661,  0.4543, -0.3030,  ...,  0.8267,  0.0140, -0.1231],
        [-0.3894, -0.0983, -0.0616,  ...,  0.7715, -0.3545, -0.2240],
        [ 0.1107, -0.0680, -0.3726,  ...,  0.7007, -0.0661,  0.2489],
        [ 0.2488, -0.1279, -0.0380,  ...,  0.7578, -0.1671, -0.0684]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[-0.0106,  0.1848, -0.2756,  ...,  0.9082,  0.0035, -0.1671],
        [ 0.0280, -0.0195, -0.0325,  ...,  0.8823, -0.0163, -0.0505],
        [ 0.0695, -0.0257, -0.2981,  ...,  1.1797, -0.0340, -0.2681],
        [ 0.2864,  0.0929, -0.0852,  ...,  0.6909,  0.1057,  0.2188]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[-0.0404,  0.1559, -0.3806,  ...,  0.8076, -0.2365,  0.3337],
        [ 0.4907, -0.1143, -0.2456,  ...,  0.7056, -0.0012,  0.1569],
        [ 0.0352,  0.0173, -0.3401,  ...,  1.0000, -0.1342, -0.0186],
        [ 0.4011,  0.0381,  0.0206,  ...,  0.3984, -0.1566,  0.0812]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.1172,  0.0485, -0.1625,  ...,  0.7368, -0.0712,  0.1803],
        [-0.2023, -0.1515, -0.2351,  ...,  0.6841, -0.4761, -0.2837],
        [ 0.2228,  0.2096, -0.2424,  ...,  0.5337,  0.2056,  0.0640],
        [ 0.1937, -0.0029, -0.2238,  ...,  0.4216,  0.1313, -0.1494]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.3779,  0.2206, -0.4319,  ...,  0.7017,  0.2822,  0.1338],
        [ 0.0011, -0.2925, -0.1987,  ...,  0.5581, -0.2423,  0.1422],
        [ 0.2178, -0.0397, -0.3047,  ...,  0.5659, -0.1199,  0.1057],
        [ 0.1482,  0.1920, -0.1379,  ...,  0.7148, -0.2808,  0.1604]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[-0.1067, -0.1467, -0.1243,  ...,  0.9990, -0.1074, -0.0272],
        [ 0.0374,  0.0719, -0.0740,  ...,  0.6963,  0.0574,  0.0632],
        [-0.1260, -0.0774, -0.2396,  ...,  0.7065, -0.1770, -0.2460],
        [-0.2041,  0.4512, -0.5044,  ...,  0.6284,  0.2944, -0.1466]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.4055,  0.1094, -0.1439,  ...,  0.5684, -0.1877,  0.3169],
        [ 0.0085,  0.0430, -0.3096,  ...,  0.6216,  0.0871, -0.0612],
        [ 0.1671, -0.1666, -0.3408,  ...,  0.5830, -0.0785,  0.3564],
        [ 0.2267, -0.0832, -0.1753,  ...,  0.7178,  0.3110,  0.0562]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.0398, -0.1118, -0.2625,  ...,  0.9287, -0.0765,  0.1164],
        [ 0.3464, -0.4553,  0.0207,  ...,  0.7871, -0.0479,  0.0854],
        [ 0.1859,  0.0530, -0.0196,  ...,  0.6152,  0.0323,  0.2739],
        [ 0.0948,  0.2428, -0.2037,  ...,  0.6133, -0.0996, -0.0147]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.7070,  0.0284, -0.3110,  ...,  0.8242, -0.0871,  0.2291],
        [ 0.4062,  0.0551, -0.5225,  ...,  0.6157, -0.0737,  0.1801],
        [ 0.0516,  0.3030, -0.2891,  ...,  0.4734,  0.0575, -0.2739],
        [-0.0582, -0.1447, -0.2812,  ...,  0.8198, -0.3494,  0.0587]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.0696,  0.2161, -0.2236,  ...,  0.4255, -0.0866,  0.0216],
        [ 0.1256,  0.1448, -0.2542,  ...,  0.9243, -0.1664, -0.1140],
        [ 0.0643,  0.1206, -0.4717,  ...,  0.8394, -0.2296, -0.0839],
        [-0.2361, -0.0848, -0.4419,  ...,  0.3350, -0.2291, -0.2087]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.4033,  0.2903, -0.3459,  ...,  0.6089, -0.0372, -0.0357],
        [ 0.2603,  0.1576, -0.4971,  ...,  0.4358,  0.1656, -0.1364],
        [ 0.2932,  0.1122, -0.3726,  ...,  0.8867, -0.0033,  0.0463],
        [-0.1154,  0.3257, -0.5029,  ...,  0.7461,  0.1205,  0.0232]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Loss after 690 Batch is nan 
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBac

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.6978,  0.0665, -0.1921,  ...,  0.7666,  0.1553,  0.1366],
        [ 0.0604, -0.1218, -0.3359,  ...,  0.8208,  0.0250,  0.1140],
        [ 0.3320, -0.2834, -0.1442,  ...,  0.4763, -0.4365,  0.3267],
        [ 0.5571,  0.2288, -0.2484,  ...,  0.5840, -0.1204, -0.0103]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[-0.0104,  0.1411, -0.1469,  ...,  0.5854,  0.0482,  0.2598],
        [ 0.6699, -0.0848, -0.5586,  ...,  0.6953,  0.4121,  0.1543],
        [ 0.2423,  0.0584, -0.3376,  ...,  0.8022, -0.5537,  0.3789],
        [ 0.0946, -0.0695, -0.0884,  ...,  0.9307, -0.0864,  0.1004]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.2844,  0.0914, -0.5356,  ...,  0.9722, -0.1307,  0.0368],
        [ 0.2532,  0.0597, -0.2080,  ...,  0.3506, -0.0180, -0.1775],
        [ 0.2874,  0.4231, -0.2664,  ...,  0.8340,  0.1644,  0.1672],
        [ 0.0039, -0.0806, -0.3442,  ...,  0.2437,  0.4548,  0.1086]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.1528,  0.3293, -0.2776,  ...,  0.1885, -0.0562,  0.4153],
        [ 0.1974, -0.0650, -0.4260,  ...,  0.6504, -0.4021,  0.1270],
        [ 0.2247, -0.0402, -0.4597,  ...,  0.9448,  0.0903,  0.3140],
        [ 0.2866,  0.3525, -0.3242,  ...,  0.4402,  0.0447,  0.2776]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.0013, -0.2468,  0.0575,  ...,  0.3870,  0.4707, -0.0594],
        [ 0.4819,  0.4771, -0.2544,  ...,  0.3997, -0.0169, -0.2568],
        [ 0.2695,  0.1174, -0.2925,  ...,  0.5562, -0.2190,  0.1736],
        [ 0.1298,  0.2942, -0.5234,  ...,  0.8408,  0.1321,  0.2715]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Loss after 750 Batch is nan 
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBac

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.1661,  0.1005, -0.6030,  ...,  1.1816, -0.0101, -0.2223],
        [ 0.0957,  0.2729, -0.3159,  ...,  0.8218,  0.1025,  0.1013],
        [ 0.1004,  0.0279, -0.1699,  ...,  0.8213, -0.1796, -0.2427],
        [ 0.0475,  0.1383, -0.2981,  ...,  0.4866,  0.0729,  0.2185]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.1124,  0.1469, -0.5449,  ...,  0.7104,  0.1401, -0.0343],
        [-0.0429, -0.1365, -0.2629,  ...,  0.7583, -0.1448,  0.1176],
        [ 0.1466, -0.0142, -0.3054,  ...,  0.7705, -0.1305,  0.1552],
        [-0.0821, -0.1394, -0.1877,  ...,  0.5898,  0.0827, -0.0710]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.1842,  0.2839, -0.4741,  ...,  0.6230,  0.0880, -0.0547],
        [ 0.2240,  0.0311, -0.1353,  ...,  0.3223, -0.2202, -0.0018],
        [ 0.1846,  0.2369, -0.3748,  ...,  0.9258,  0.0605,  0.3230],
        [-0.1611,  0.0283, -0.1143,  ...,  0.7681,  0.0191, -0.0623]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.5566, -0.2391, -0.3757,  ...,  0.6509,  0.2922,  0.1041],
        [ 0.1017,  0.0181, -0.3772,  ...,  1.1504, -0.0518, -0.0542],
        [ 0.0851,  0.1109, -0.3477,  ...,  0.8906,  0.0435, -0.0080],
        [ 0.1508,  0.2119, -0.1750,  ...,  0.7656, -0.1013, -0.0370]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.0398,  0.2230, -0.1732,  ...,  0.6934, -0.2253,  0.2573],
        [ 0.2020,  0.3069, -0.2218,  ...,  0.5205,  0.0283,  0.2593],
        [ 0.1536,  0.1877, -0.2705,  ...,  1.0352, -0.2720,  0.0208],
        [ 0.3135, -0.0225, -0.4827,  ...,  0.6782, -0.0818, -0.0464]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Loss after 810 Batch is nan 
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outp

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[-0.1326,  0.5000, -0.0482,  ...,  0.7383,  0.2971,  0.0697],
        [-0.1172,  0.6338, -0.5850,  ...,  0.8032, -0.0102,  0.1265],
        [-0.0371,  0.3469, -0.1682,  ...,  0.9517,  0.0992, -0.2067],
        [ 0.1565,  0.1885, -0.3787,  ...,  0.7910, -0.1589, -0.4048]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.2137,  0.2781, -0.4702,  ...,  0.5767,  0.0770, -0.1665],
        [-0.1242, -0.0136, -0.0351,  ...,  0.4688,  0.1570, -0.2155],
        [ 0.2766, -0.2338,  0.0449,  ...,  0.5273, -0.3022, -0.2830],
        [-0.1028, -0.1704, -0.1324,  ...,  0.8979, -0.0683,  0.2145]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.0586,  0.2529, -0.2312,  ...,  0.7656, -0.4924,  0.4680],
        [ 0.4988,  0.3372, -0.3250,  ...,  0.8687, -0.1428,  0.2441],
        [ 0.3220,  0.1588, -0.7329,  ...,  0.6553, -0.0200,  0.0712],
        [ 0.2141,  0.4146, -0.1746,  ...,  0.5913, -0.0163, -0.0976]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.1785,  0.6138, -0.2046,  ...,  0.6201, -0.1058, -0.0380],
        [ 0.2668,  0.1898, -0.1616,  ...,  0.5923,  0.2106, -0.1617],
        [ 0.1716, -0.0673, -0.2703,  ...,  0.7788, -0.3176,  0.1909],
        [-0.0626,  0.3005, -0.1956,  ...,  0.6934, -0.1350, -0.0777]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.2407,  0.1195, -0.6558,  ...,  0.5752,  0.2583,  0.2438],
        [ 0.0957,  0.1669,  0.0668,  ...,  0.4548,  0.0027,  0.0384],
        [ 0.0965,  0.2388, -0.4893,  ...,  0.8975,  0.0834,  0.5513],
        [-0.0767,  0.0351, -0.5068,  ...,  0.9058, -0.2737, -0.0302]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Loss after 870 Batch is nan 
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBac

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.2067,  0.0024, -0.4167,  ...,  1.1807,  0.0260, -0.0302],
        [ 0.2084,  0.0446, -0.1805,  ...,  0.5938,  0.0497, -0.2397],
        [ 0.4399,  0.8149, -0.3606,  ...,  0.8086, -0.0776, -0.0558],
        [ 0.3643,  0.1337, -0.2385,  ...,  0.5664, -0.2048,  0.4526]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.0941, -0.1255, -0.5566,  ...,  0.7607,  0.3743,  0.1879],
        [ 0.2649,  0.1412, -0.3582,  ...,  0.7305, -0.1538,  0.2976],
        [ 0.4478,  0.4846, -0.3994,  ...,  0.9663,  0.1638, -0.0964],
        [ 0.1997,  0.2759, -0.3101,  ...,  0.6772,  0.1681, -0.0308]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.1390,  0.2468, -0.2607,  ...,  0.9414, -0.2067, -0.0486],
        [ 0.3118, -0.1333, -0.6079,  ...,  0.5605, -0.0831,  0.1638],
        [ 0.2180, -0.0209, -0.2300,  ...,  0.9346, -0.1742,  0.2446],
        [ 0.2861,  0.3010, -0.3164,  ...,  0.6074, -0.3079, -0.1488]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[-0.0018,  0.1111, -0.3796,  ...,  0.9185,  0.2258, -0.0948],
        [ 0.5015,  0.4766, -0.2162,  ...,  0.8574, -0.1190,  0.0991],
        [ 0.2010, -0.2286, -0.4641,  ...,  0.8340,  0.0862,  0.2126],
        [ 0.0785, -0.0446, -0.5112,  ...,  1.0088, -0.3748,  0.1036]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.0919, -0.1011, -0.1171,  ...,  0.7178,  0.0243,  0.0160],
        [ 0.5132,  0.1683, -0.4453,  ...,  0.8608, -0.0941,  0.1276],
        [ 0.1948, -0.0145, -0.3948,  ...,  0.7437,  0.0598,  0.2002],
        [ 0.2280, -0.0061, -0.2563,  ...,  0.4917, -0.3557,  0.0580]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Loss after 930 Batch is nan 
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBac

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.3652,  0.1111, -0.3191,  ...,  0.3748, -0.1414,  0.0433],
        [ 0.6880, -0.0745, -0.3887,  ...,  0.8491,  0.2087,  0.0517],
        [ 0.0356,  0.1210, -0.3823,  ...,  0.9512, -0.0396, -0.1129],
        [ 0.2052, -0.1532, -0.2622,  ...,  0.6362,  0.0876,  0.2211]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.1116,  0.0607, -0.1670,  ...,  0.8750, -0.1301,  0.0090],
        [ 0.0135,  0.1388, -0.4924,  ...,  1.1689,  0.1265, -0.0415],
        [ 0.2385,  0.3706, -0.3428,  ...,  0.6548, -0.0462,  0.0305],
        [ 0.0164, -0.0412, -0.1417,  ...,  0.9185, -0.3628,  0.1046]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[-0.1831,  0.2432, -0.4058,  ...,  0.7031, -0.2474,  0.3455],
        [-0.1899, -0.0559, -0.3962,  ...,  0.9487, -0.3875,  0.1375],
        [ 0.8047,  0.1791, -0.2866,  ...,  0.9170,  0.4443,  0.1633],
        [ 0.4189,  0.4292, -0.3765,  ...,  0.6724,  0.2688,  0.0275]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.4541,  0.1323, -0.4368,  ...,  0.6060,  0.1405,  0.0811],
        [-0.1689,  0.2058, -0.3311,  ...,  0.7827, -0.1018,  0.3530],
        [ 0.2869,  0.2177, -0.4993,  ...,  0.3281, -0.1109,  0.2184],
        [ 0.1908,  0.2842, -0.1757,  ...,  0.6489, -0.0080, -0.0594]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.5068,  0.4138, -0.6113,  ...,  0.7129,  0.0685, -0.0160],
        [-0.1611,  0.1490, -0.3198,  ...,  0.8579, -0.1359,  0.0931],
        [ 0.3289,  0.1208, -0.3215,  ...,  0.9771,  0.0316,  0.4907],
        [ 0.2472,  0.0294, -0.1305,  ...,  0.5400, -0.2090, -0.0059]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Loss after 990 Batch is nan 
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBac

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.3413, -0.0358, -0.1849,  ...,  0.8262, -0.1010,  0.3572],
        [ 0.1006,  0.0988, -0.1833,  ...,  1.0537,  0.0810, -0.0624],
        [ 0.3823,  0.1390, -0.1372,  ...,  0.4893, -0.1152,  0.2469],
        [ 0.4277,  0.2184, -0.4343,  ...,  0.7163,  0.0378,  0.0327]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.5093,  0.1586, -0.3635,  ...,  0.6387,  0.1090, -0.1021],
        [ 0.2500,  0.4287, -0.2820,  ...,  0.4346, -0.1133, -0.2401],
        [ 0.3723,  0.0357, -0.2137,  ...,  0.6523, -0.0421,  0.3469],
        [ 0.1582,  0.2583, -0.1009,  ...,  0.8506,  0.1940,  0.1080]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.4927,  0.3459, -0.0690,  ...,  0.3206, -0.0783,  0.2979],
        [ 0.3320,  0.1357, -0.2393,  ...,  0.5542,  0.2080,  0.2079],
        [ 0.4653,  0.2379, -0.3533,  ...,  0.4407, -0.0326,  0.0623],
        [-0.0789,  0.2776, -0.1669,  ...,  0.6724, -0.0822,  0.1573]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.1189,  0.2220, -0.3918,  ...,  0.8232, -0.1897,  0.1389],
        [ 0.2346,  0.3845, -0.4370,  ...,  0.5967,  0.1729, -0.0675],
        [ 0.0661,  0.2438, -0.3374,  ...,  0.7490, -0.1284,  0.1620],
        [-0.0443,  0.2087, -0.0653,  ...,  0.9414, -0.0087,  0.0602]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.3049,  0.0015, -0.5244,  ...,  0.8325, -0.1030,  0.4268],
        [ 0.2852, -0.2830, -0.4995,  ...,  0.7285,  0.1558,  0.0874],
        [ 0.3896,  0.0578, -0.2656,  ...,  0.7529,  0.1896, -0.1768],
        [-0.0897,  0.0389,  0.0089,  ...,  0.6001, -0.3120,  0.0373]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Loss after 1050 Batch is nan 
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBa

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 3.4302e-01,  2.6758e-01, -2.9468e-01,  ...,  3.4375e-01,
         -9.0088e-02, -1.0811e-02],
        [ 1.4905e-01,  8.8440e-02, -5.2881e-01,  ...,  8.1299e-01,
         -4.4365e-03, -7.0557e-02],
        [ 5.9013e-03,  1.7542e-01, -5.4736e-01,  ...,  9.2090e-01,
         -3.4404e-04,  1.6870e-01],
        [ 2.4011e-01, -1.6870e-01, -4.4189e-01,  ...,  7.7197e-01,
          5.2856e-02, -2.5238e-02]], device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [n

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[-1.1224e-01, -1.6553e-01, -5.6738e-01,  ...,  7.1191e-01,
         -1.6113e-01,  6.6467e-02],
        [ 1.1780e-01, -2.5439e-01, -6.2402e-01,  ...,  9.0625e-01,
         -1.6907e-01,  1.1676e-01],
        [ 3.2324e-01, -3.8025e-02, -3.6719e-01,  ...,  5.7715e-01,
          3.6311e-04,  2.2424e-01],
        [ 3.0908e-01,  1.4563e-01, -3.2812e-01,  ...,  4.8096e-01,
          1.8518e-01,  4.7827e-01]], device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [n

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.4414, -0.0045, -0.2354,  ...,  0.8521,  0.0817, -0.1815],
        [ 0.2274,  0.0558, -0.2722,  ...,  0.9507,  0.1180, -0.0919],
        [ 0.2600,  0.0258, -0.0921,  ...,  0.3647, -0.0706, -0.0474],
        [-0.0251,  0.0732, -0.2595,  ...,  1.0967, -0.1543, -0.0126]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.3289,  0.2446, -0.1057,  ...,  0.6157,  0.0822, -0.1351],
        [ 0.1394,  0.0049, -0.0677,  ...,  0.5098, -0.0935, -0.1209],
        [-0.1910,  0.3499, -0.2578,  ...,  0.9424, -0.4263, -0.2330],
        [ 0.4407, -0.0577, -0.2605,  ...,  0.8589,  0.1597,  0.1556]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.3506, -0.0602, -0.2546,  ...,  0.4419, -0.2341,  0.3638],
        [-0.0241, -0.1476, -0.4346,  ...,  1.0596, -0.0168,  0.1285],
        [ 0.2878,  0.0768, -0.1332,  ...,  0.3838,  0.3547,  0.2517],
        [ 0.2734,  0.0607, -0.2937,  ...,  0.4695, -0.1228,  0.0396]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Loss after 1110 Batch is nan 
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBa

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.2278,  0.1375, -0.2903,  ...,  0.7793, -0.1722, -0.1007],
        [ 0.2086,  0.1033, -0.2856,  ...,  0.7510,  0.1515, -0.0752],
        [ 0.2372,  0.2075, -0.6211,  ...,  0.5024, -0.0301,  0.0249],
        [ 0.3047,  0.0881, -0.2627,  ...,  0.8433,  0.2306,  0.1116]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.2834,  0.4490, -0.2374,  ...,  0.8647, -0.2739,  0.1438],
        [ 0.5439,  0.2622, -0.6616,  ...,  0.7700,  0.1852, -0.0855],
        [ 0.3557,  0.0742, -0.4424,  ...,  1.0381, -0.0280,  0.3694],
        [ 0.2656,  0.3196, -0.1902,  ...,  0.5391,  0.0431,  0.1180]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.5649,  0.0078, -0.3552,  ...,  0.5708, -0.0715,  0.0254],
        [ 0.0420, -0.2188, -0.1159,  ...,  0.6694, -0.2115,  0.0996],
        [ 0.1881, -0.0251, -0.3269,  ...,  0.5171,  0.1247,  0.2358],
        [ 0.2817,  0.1639, -0.2783,  ...,  0.6714,  0.0087, -0.1177]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.3977,  0.0151, -0.2825,  ...,  0.5444,  0.0242,  0.1674],
        [-0.0828,  0.2155,  0.1808,  ...,  0.8311, -0.1896, -0.1583],
        [ 0.0139,  0.0798, -0.3997,  ...,  0.5820, -0.1410, -0.4648],
        [ 0.2715, -0.0785, -0.1599,  ...,  0.6055,  0.2764,  0.0715]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.0916, -0.1810, -0.5605,  ...,  0.8569, -0.1399,  0.0724],
        [ 0.1196, -0.0086, -0.2693,  ...,  0.6914, -0.1235, -0.1052],
        [ 0.0879,  0.4963, -0.4683,  ...,  0.5781, -0.1940, -0.1663],
        [ 0.0934,  0.7178, -0.0428,  ...,  0.7671, -0.0089,  0.1493]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Loss after 1170 Batch is nan 
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBa

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.2437,  0.6245, -0.3669,  ...,  0.8716, -0.1512, -0.0453],
        [ 0.0522, -0.1060,  0.2505,  ...,  0.9336, -0.0413,  0.0601],
        [ 0.0742,  0.1199, -0.3135,  ...,  0.5332,  0.2358,  0.3533],
        [ 0.2277,  0.2776, -0.3679,  ...,  0.7500, -0.0895, -0.0081]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.2461,  0.2068, -0.2162,  ...,  0.6113, -0.0012, -0.0640],
        [ 0.0967,  0.4050, -0.3359,  ...,  0.9995, -0.0341,  0.1093],
        [ 0.2988,  0.1873, -0.2869,  ...,  0.7046,  0.0551,  0.2438],
        [ 0.6245,  0.2489, -0.3501,  ...,  0.6709,  0.2510,  0.4529]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[-0.0161, -0.2046, -0.2296,  ...,  1.0596, -0.1798, -0.2068],
        [ 0.5957, -0.2338, -0.1256,  ...,  0.4607,  0.1232,  0.3970],
        [ 0.3486,  0.0023, -0.5347,  ...,  0.6152, -0.0601,  0.0429],
        [ 0.1060,  0.2532, -0.3088,  ...,  0.9116,  0.1477,  0.1349]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.3203,  0.5366, -0.2048,  ...,  0.7119, -0.2117, -0.2227],
        [ 0.0408,  0.1885, -0.0879,  ...,  0.9858, -0.0466,  0.1265],
        [-0.1130, -0.0995, -0.0480,  ...,  0.2622, -0.2698,  0.2419],
        [ 0.2355, -0.3770, -0.2200,  ...,  0.3813, -0.4055,  0.1976]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.2502,  0.4512, -0.4622,  ...,  0.7559,  0.0889,  0.3196],
        [ 0.0243, -0.3540, -0.4807,  ...,  0.6128,  0.0751,  0.0972],
        [ 0.2002, -0.1511, -0.0729,  ...,  0.8198, -0.0081, -0.2101],
        [ 0.1683, -0.0905, -0.0966,  ...,  0.6631, -0.1309,  0.2330]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Loss after 1230 Batch is nan 
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBa

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.1145,  0.2712, -0.3318,  ...,  1.0615, -0.1174,  0.1750],
        [-0.0415, -0.0089, -0.1116,  ...,  0.8931, -0.1149, -0.0457],
        [ 0.7046, -0.0495, -0.0444,  ...,  0.5181,  0.0146,  0.0289],
        [-0.0844,  0.1062, -0.1561,  ...,  0.7036, -0.1057,  0.1160]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.1648,  0.3867, -0.3381,  ...,  0.5986,  0.2095,  0.1451],
        [-0.3293,  0.2119, -0.3723,  ...,  1.2646, -0.1078, -0.0296],
        [ 0.2358,  0.3164, -0.3616,  ...,  0.2242, -0.0881,  0.0023],
        [ 0.3975,  0.0850, -0.1964,  ...,  0.9600, -0.2590,  0.5933]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.0284,  0.3318,  0.1486,  ...,  0.5903, -0.0326, -0.1263],
        [ 0.1814,  0.1183, -0.1691,  ...,  0.6421,  0.0398,  0.2546],
        [-0.0983,  0.2817, -0.0676,  ...,  0.6719, -0.3440,  0.0911],
        [ 0.2057,  0.0022,  0.0398,  ...,  0.3760, -0.0936,  0.3591]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[-0.0405,  0.2219, -0.3525,  ...,  0.9175,  0.3027,  0.0114],
        [ 0.2705, -0.5625, -0.3699,  ...,  0.8623,  0.2192,  0.4341],
        [ 0.3037,  0.1503, -0.3584,  ...,  0.8257, -0.1445,  0.0947],
        [ 0.2395,  0.1764, -0.4739,  ...,  0.8252, -0.0453, -0.0395]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.7124,  0.1874, -0.1686,  ...,  1.0225, -0.2034,  0.0451],
        [ 0.0797,  0.0819, -0.3916,  ...,  0.7700, -0.0111,  0.1926],
        [ 0.2415,  0.1768, -0.1970,  ...,  0.6172, -0.1387,  0.4797],
        [ 0.3936,  0.1399, -0.3674,  ...,  0.7607, -0.0914,  0.2246]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Loss after 1290 Batch is nan 
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBa

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.1881,  0.0243, -0.4087,  ...,  0.8345, -0.1970, -0.0248],
        [ 0.0500, -0.0028, -0.3081,  ...,  0.5146, -0.2563,  0.0287],
        [ 0.2578, -0.1539, -0.2952,  ...,  0.4739, -0.1929,  0.1417],
        [ 0.1903,  0.4277, -0.4573,  ...,  0.9019,  0.0648,  0.2178]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.1941,  0.1573, -0.3381,  ...,  0.6069, -0.1976,  0.1184],
        [ 0.2729, -0.0179, -0.3594,  ...,  0.9878,  0.1398,  0.2534],
        [ 0.0673,  0.0327, -0.4919,  ...,  0.4915,  0.0939, -0.0703],
        [ 0.3284,  0.0062, -0.2559,  ...,  0.6255,  0.0374,  0.1068]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.2239,  0.0309, -0.2952,  ...,  0.6343, -0.0131,  0.3096],
        [-0.0302, -0.0141, -0.3154,  ...,  0.7769, -0.0804,  0.6006],
        [-0.1753, -0.0053, -0.3989,  ...,  0.9058, -0.4172, -0.0338],
        [ 0.2998, -0.1819, -0.2385,  ...,  0.9175,  0.0304,  0.0831]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[-0.1081,  0.2209,  0.1017,  ...,  0.7764, -0.0864, -0.2081],
        [ 0.1321,  0.1501, -0.3035,  ...,  0.6572,  0.2224, -0.0828],
        [ 0.0840,  0.2622, -0.0668,  ...,  0.9741, -0.0602,  0.1216],
        [ 0.1001, -0.1490, -0.3904,  ...,  0.7246, -0.1608,  0.2222]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[-0.1193,  0.0703,  0.0478,  ...,  0.3276,  0.0323,  0.0845],
        [-0.3237, -0.1165, -0.4680,  ...,  0.3542, -0.3176,  0.3940],
        [ 0.1458, -0.1471, -0.1312,  ...,  0.5200, -0.1982,  0.1770],
        [ 0.2954, -0.0243, -0.3667,  ...,  0.7500,  0.0938,  0.1356]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Loss after 1350 Batch is nan 
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBa

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.2385, -0.0499, -0.5078,  ...,  0.7197,  0.1471,  0.3625],
        [ 0.3584, -0.1038, -0.1532,  ...,  0.5698, -0.3096,  0.1541],
        [ 0.0900,  0.1805, -0.3564,  ...,  0.6714, -0.0264, -0.1838],
        [ 0.1763, -0.2791, -0.5747,  ...,  0.6475,  0.0869,  0.3213]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.2450,  0.2808, -0.2673,  ...,  1.0078, -0.1853,  0.0670],
        [ 0.3567,  0.0734, -0.3184,  ...,  0.6729, -0.4089,  0.2524],
        [ 0.7241,  0.3564,  0.1014,  ...,  0.5527,  0.2438, -0.0852],
        [ 0.1081,  0.0583, -0.4534,  ...,  0.4272, -0.3323,  0.1302]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[-0.1818, -0.0778, -0.4419,  ...,  0.6445,  0.2458,  0.3862],
        [ 0.4473,  0.3037, -0.4929,  ...,  0.7847,  0.0867, -0.0190],
        [ 0.2062,  0.3799, -0.0975,  ...,  0.8384, -0.0092,  0.3242],
        [ 0.2783, -0.0042, -0.4792,  ...,  0.5684, -0.0723, -0.1373]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.3958, -0.0403, -0.3757,  ...,  0.8599, -0.3369, -0.0331],
        [ 0.3015,  0.2081, -0.2285,  ...,  0.6416, -0.0508,  0.0227],
        [ 0.1281, -0.0986, -0.2837,  ...,  0.6929, -0.0625,  0.0862],
        [-0.0600, -0.0647, -0.3396,  ...,  0.7266,  0.1525,  0.2416]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.0630,  0.0293, -0.4353,  ...,  0.7500, -0.3630,  0.4797],
        [ 0.0886, -0.0561, -0.3320,  ...,  0.5122, -0.1130,  0.0159],
        [ 0.1307, -0.1981, -0.1225,  ...,  0.6240, -0.4016,  0.0976],
        [ 0.3665,  0.0817, -0.4463,  ...,  0.4958, -0.1403,  0.3127]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Loss after 1410 Batch is nan 
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBa

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.1526,  0.3831, -0.4302,  ...,  0.8633, -0.1660,  0.0758],
        [ 0.1874,  0.2729, -0.4666,  ...,  0.3809, -0.1107, -0.0300],
        [ 0.0289, -0.1092, -0.3606,  ...,  0.9727,  0.1005,  0.0142],
        [ 0.5405,  0.3057, -0.1541,  ...,  0.6235,  0.1404, -0.2078]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.2722, -0.0498, -0.5596,  ...,  0.8643, -0.1017,  0.1898],
        [ 0.2444,  0.2983, -0.2563,  ...,  0.6782,  0.1230, -0.1510],
        [ 0.4441,  0.0108, -0.3887,  ...,  0.7622,  0.2090,  0.2629],
        [ 0.3958, -0.1329, -0.3630,  ...,  0.4287,  0.0500, -0.1018]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.0130,  0.0447, -0.2025,  ...,  0.4685, -0.2874,  0.0865],
        [ 0.0271,  0.2947, -0.1256,  ...,  0.8022,  0.1821, -0.4077],
        [ 0.0673,  0.2017, -0.1873,  ...,  1.1172, -0.3298,  0.4104],
        [ 0.2571,  0.6455, -0.5405,  ...,  0.6294, -0.1646, -0.1617]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.6138, -0.1761, -0.2095,  ...,  0.6270,  0.2213,  0.4373],
        [ 0.2125,  0.1797, -0.1176,  ...,  0.5342,  0.2296,  0.4714],
        [ 0.0350, -0.0945, -0.3428,  ...,  0.4473, -0.1206,  0.0528],
        [ 0.0441,  0.1090, -0.4521,  ...,  0.6616, -0.3740,  0.0759]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.2468,  0.3660, -0.0888,  ...,  0.7910,  0.0360, -0.1078],
        [ 0.2908, -0.4580, -0.4758,  ...,  0.6245,  0.2131,  0.0286],
        [ 0.3662,  0.0551, -0.3850,  ...,  0.8667, -0.1382, -0.1997],
        [ 0.2180, -0.0802,  0.0389,  ...,  0.4075, -0.2725,  0.3203]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.0589, -0.3232, -0.0613,  ...,  0.4294,  0.0721,  0.2639],
        [ 0.0533, -0.0424, -0.3547,  ...,  1.0186,  0.2563,  0.2179],
        [-0.1541,  0.4541, -0.4590,  ...,  0.5181, -0.1578,  0.4670],
        [ 0.5664, -0.0114, -0.0801,  ...,  0.7334, -0.1694,  0.0771]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Loss after 1490 Batch is nan 
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBa

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 2.6245e-01, -1.3879e-01, -2.5513e-01,  ...,  6.3574e-01,
         -5.7739e-02, -3.1710e-04],
        [-1.3477e-01,  3.6194e-02, -3.7231e-01,  ...,  8.8574e-01,
          1.9360e-01, -4.0991e-01],
        [ 4.0869e-01,  1.7249e-01, -4.9634e-01,  ...,  7.4561e-01,
         -9.2712e-02,  1.2451e-01],
        [ 6.9702e-02, -2.2803e-01, -2.2729e-01,  ...,  6.8115e-01,
          6.7871e-02,  6.2256e-01]], device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [n

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.1479,  0.2854, -0.2354,  ...,  0.7988, -0.0809, -0.0320],
        [ 0.3538,  0.0981, -0.1686,  ...,  0.6426,  0.2812,  0.4529],
        [ 0.2632,  0.3630,  0.0741,  ...,  0.8384,  0.3171, -0.2340],
        [ 0.3298,  0.0091, -0.5342,  ...,  0.6831, -0.1434, -0.0154]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Loss after 1530 Batch is nan 
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.0731,  0.1967, -0.2837,  ...,  0.7617,  0.0145, -0.0975],
        [ 0.0806, -0.0095, -0.2969,  ...,  0.7168, -0.0925,  0.2170],
        [-0.1538, -0.1138, -0.1055,  ...,  0.7021, -0.1527, -0.2544],
        [ 0.3699,  0.1335, -0.3220,  ...,  0.7222, -0.3152,  0.0605]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBa

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[-0.0275, -0.0945, -0.1656,  ...,  1.0283, -0.1637, -0.0997],
        [-0.2323,  0.3105, -0.1285,  ...,  0.6890, -0.2440,  0.2881],
        [ 0.5044,  0.1722, -0.2286,  ...,  0.3516, -0.2959,  0.4043],
        [-0.1823,  0.2358, -0.1032,  ...,  0.5205, -0.2976, -0.0028]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.1381,  0.1774, -0.0720,  ...,  0.4006, -0.3057,  0.3831],
        [ 0.0487,  0.3596, -0.2292,  ...,  0.5029,  0.1160, -0.1058],
        [ 0.3350,  0.1566, -0.1272,  ...,  0.8076, -0.1753,  0.2166],
        [ 0.2644,  0.0612, -0.3660,  ...,  0.6973,  0.3167, -0.1599]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Loss after 1560 Batch is nan 
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBa

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.0529,  0.0499, -0.2009,  ...,  0.7754, -0.1698,  0.0762],
        [ 0.4402,  0.4468, -0.6582,  ...,  0.6963,  0.2416,  0.0090],
        [-0.0545,  0.0334, -0.3860,  ...,  0.6875,  0.1487,  0.2725],
        [ 0.0618, -0.1182, -0.4382,  ...,  0.7949, -0.0910,  0.2800]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.4216,  0.0130, -0.5391,  ...,  0.1735,  0.0138, -0.0521],
        [ 0.5391,  0.2664, -0.2678,  ...,  0.5815,  0.0872, -0.0645],
        [ 0.2435, -0.1346, -0.5552,  ...,  0.6782,  0.0295,  0.0146],
        [ 0.2864,  0.1028, -0.4097,  ...,  0.7231, -0.3521,  0.1305]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.4312,  0.1448, -0.3933,  ...,  0.5425,  0.1201,  0.0617],
        [-0.2991,  0.1819, -0.3152,  ...,  0.7681,  0.0161,  0.0681],
        [ 0.3113, -0.0905, -0.0732,  ...,  0.6938, -0.1703, -0.0550],
        [ 0.0641, -0.0411, -0.3870,  ...,  0.9111, -0.2152,  0.1730]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[-0.0179,  0.0572, -0.1487,  ...,  0.7246, -0.4023,  0.2466],
        [ 0.2089,  0.1465,  0.0121,  ...,  0.7788, -0.2959, -0.1055],
        [-0.0257,  0.0509, -0.1715,  ...,  0.6118, -0.2729,  0.0493],
        [ 0.3625,  0.3755, -0.2000,  ...,  0.5811,  0.0193,  0.0333]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 2.0959e-01, -9.2459e-04, -2.5732e-01,  ...,  8.1445e-01,
         -3.5278e-01, -2.5146e-01],
        [ 1.3940e-01,  3.6475e-01, -1.9141e-01,  ...,  8.2129e-01,
         -6.2525e-05, -9.4116e-02],
        [ 1.0132e-01,  1.0358e-01, -2.8882e-01,  ...,  3.7769e-01,
          7.9193e-03, -1.2109e-01],
        [ 3.1708e-02, -1.4404e-01, -3.5449e-01,  ...,  5.5762e-01,
         -1.0522e-01,  5.5023e-02]], device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Loss after 1620 Batch is nan 
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  .

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.2289, -0.2465,  0.1692,  ...,  0.3757, -0.1218,  0.0742],
        [ 0.0320,  0.1963, -0.1207,  ...,  0.9927,  0.0542,  0.0651],
        [ 0.2732,  0.4453, -0.1770,  ...,  0.3088,  0.2666,  0.1815],
        [-0.0063, -0.0990, -0.1769,  ...,  0.8926,  0.2163,  0.4529]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.1709,  0.1587, -0.5439,  ...,  0.5454, -0.1193,  0.1653],
        [ 0.5439,  0.2693, -0.1362,  ...,  0.4785, -0.0238,  0.4214],
        [-0.0158,  0.1869, -0.4150,  ...,  1.0811, -0.0519, -0.0717],
        [ 0.2419,  0.1666, -0.2084,  ...,  0.8550,  0.1185,  0.2456]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[-0.0322, -0.0206, -0.4321,  ...,  0.7583,  0.0095, -0.1897],
        [ 0.0381, -0.4048, -0.1852,  ...,  0.5757,  0.0789,  0.2681],
        [ 0.3965,  0.2515, -0.2185,  ...,  0.5303, -0.2515, -0.1758],
        [ 0.4119,  0.0475, -0.3081,  ...,  0.8779, -0.2188,  0.0952]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.0216,  0.2456, -0.5439,  ...,  0.9746,  0.1154, -0.0261],
        [ 0.2202,  0.0781,  0.0486,  ...,  0.3376,  0.1785, -0.3962],
        [-0.1173,  0.0764, -0.2260,  ...,  0.7871, -0.1223, -0.0988],
        [ 0.2676,  0.1217, -0.3518,  ...,  0.7402, -0.2310, -0.0475]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.1066,  0.0407, -0.4673,  ...,  0.3926, -0.1222,  0.4312],
        [ 0.2198,  0.3137, -0.3765,  ...,  0.6069,  0.2031,  0.0988],
        [ 0.0665,  0.1451, -0.1426,  ...,  0.9443, -0.0416,  0.0480],
        [ 0.0415,  0.0740, -0.1168,  ...,  0.6123, -0.2698,  0.1876]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Loss after 1680 Batch is nan 
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBa

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.1207, -0.0424, -0.1964,  ...,  0.6782, -0.0874,  0.0986],
        [ 0.2059,  0.4622, -0.0237,  ...,  0.6284,  0.0878,  0.2634],
        [ 0.3687, -0.0240, -0.0628,  ...,  0.5508,  0.1165, -0.1129],
        [-0.0177, -0.1082, -0.1786,  ...,  0.6147,  0.0302,  0.1251]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.0172,  0.2075, -0.2468,  ...,  0.8232,  0.1541,  0.2423],
        [ 0.2433, -0.0837, -0.3955,  ...,  0.4802, -0.0433,  0.4834],
        [ 0.3464,  0.0759, -0.3713,  ...,  0.9087, -0.0954,  0.1448],
        [ 0.4365,  0.4531, -0.2734,  ...,  0.6138, -0.2323,  0.0744]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.3669, -0.0475, -0.5825,  ...,  0.7910, -0.2766, -0.4602],
        [ 0.3247, -0.0476, -0.3259,  ...,  0.4863, -0.0370, -0.0192],
        [-0.0385,  0.0032, -0.3279,  ...,  0.8252, -0.2837,  0.2441],
        [ 0.1989, -0.3291, -0.0531,  ...,  0.2197,  0.3135,  0.1521]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.1284,  0.1428, -0.2517,  ...,  0.5962,  0.1388, -0.0716],
        [-0.0658,  0.3384, -0.5449,  ...,  0.7930,  0.1482, -0.0190],
        [-0.2959, -0.1464, -0.2214,  ...,  0.7910, -0.2664, -0.0173],
        [ 0.1841, -0.1508, -0.1429,  ...,  0.7715, -0.0338,  0.0640]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.2664,  0.1191, -0.4578,  ...,  0.7466, -0.0579,  0.0225],
        [ 0.3647,  0.2051, -0.0073,  ...,  0.5942,  0.1355, -0.0280],
        [ 0.5801,  0.0175, -0.3418,  ...,  0.5581, -0.2925,  0.4189],
        [ 0.0927, -0.2100, -0.5186,  ...,  1.0137, -0.1041, -0.0908]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Loss after 1740 Batch is nan 
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBa

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[-0.2460,  0.2695, -0.1431,  ...,  0.5488, -0.3479, -0.0958],
        [ 0.4932,  0.3931, -0.6001,  ...,  0.7847, -0.0019,  0.2440],
        [ 0.4050, -0.0509, -0.3691,  ...,  0.6313, -0.1018,  0.0134],
        [ 0.1732,  0.1481, -0.5298,  ...,  0.6245, -0.1917, -0.0129]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.5034,  0.1744, -0.3796,  ...,  0.7661,  0.1015,  0.2245],
        [ 0.5542,  0.1353, -0.2991,  ...,  0.5835,  0.1626, -0.0490],
        [-0.0955, -0.0905, -0.3906,  ...,  1.1709, -0.1227,  0.0363],
        [-0.0294, -0.1158, -0.2847,  ...,  0.7612,  0.0075,  0.0214]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.2644, -0.1765, -0.1262,  ...,  0.7759,  0.0292,  0.3477],
        [ 0.4365,  0.1064, -0.4102,  ...,  0.3914,  0.1008,  0.1266],
        [-0.0816, -0.0423,  0.0110,  ...,  0.5234, -0.4126,  0.0927],
        [ 0.2094,  0.4790, -0.1528,  ...,  0.9043, -0.0917, -0.2458]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.2480,  0.3574, -0.2615,  ...,  0.9668,  0.0672,  0.1858],
        [ 0.3035,  0.6699, -0.6216,  ...,  0.6641,  0.0583, -0.0612],
        [ 0.3481,  0.2444, -0.0049,  ...,  0.8140,  0.2405,  0.0122],
        [ 0.0904, -0.4055, -0.0432,  ...,  0.6025,  0.1636,  0.3647]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.4604,  0.5767, -0.1067,  ...,  0.7417,  0.1415,  0.5112],
        [ 0.0508,  0.1851, -0.3796,  ...,  0.8003, -0.0876, -0.0820],
        [ 0.6118, -0.1064, -0.0291,  ...,  0.7168,  0.2305,  0.2009],
        [ 0.3474,  0.4170, -0.2786,  ...,  0.7310,  0.0130,  0.0588]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Loss after 1800 Batch is nan 
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBa

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.0839, -0.1759, -0.3206,  ...,  0.7979,  0.2869,  0.1957],
        [ 0.3945, -0.0467, -0.0838,  ...,  1.0518,  0.3210, -0.1519],
        [ 0.1334,  0.0181, -0.0977,  ...,  0.8438, -0.0057,  0.1473],
        [ 0.3875,  0.1641, -0.0541,  ...,  0.5874, -0.2108, -0.3074]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.1742,  0.2311, -0.1515,  ...,  0.6577, -0.1871, -0.0985],
        [ 0.6519,  0.2487, -0.4690,  ...,  0.9277,  0.0675,  0.2056],
        [ 0.2000,  0.0730, -0.3330,  ...,  0.6543,  0.0908, -0.1481],
        [ 0.0195,  0.0024, -0.2852,  ...,  0.2681,  0.0595,  0.0174]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.1613,  0.1506, -0.1842,  ...,  0.6489,  0.1302,  0.5669],
        [ 0.4089,  0.1230, -0.3850,  ...,  0.5474, -0.3484,  0.0818],
        [ 0.1877, -0.2168, -0.0934,  ...,  0.1493,  0.1846,  0.0185],
        [ 0.3438,  0.0973, -0.2932,  ...,  0.6841, -0.2605,  0.3469]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[-0.2118,  0.1294, -0.4851,  ...,  0.8936, -0.2069,  0.0282],
        [ 0.2307,  0.0858, -0.1772,  ...,  1.0020,  0.1103,  0.0083],
        [ 0.0782,  0.1763, -0.3147,  ...,  0.2346, -0.2981,  0.5874],
        [ 0.3311,  0.3196, -0.5620,  ...,  0.2358,  0.0921,  0.4360]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.4229,  0.1236, -0.3977,  ...,  0.7759,  0.0096,  0.2502],
        [ 0.2072, -0.0502, -0.1436,  ...,  0.5728, -0.2308,  0.0402],
        [ 0.4773,  0.1810, -0.7598,  ...,  0.5449,  0.4363,  0.1978],
        [ 0.2389, -0.1506,  0.0462,  ...,  0.4290, -0.0502,  0.0476]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Loss after 1860 Batch is nan 
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBa

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.2086,  0.1782, -0.5806,  ...,  0.9282, -0.2090,  0.1036],
        [ 0.2546,  0.0268, -0.2700,  ...,  0.9028,  0.1326, -0.1605],
        [ 0.4829,  0.1299, -0.4006,  ...,  0.7412,  0.4485, -0.1455],
        [ 0.6011,  0.0784, -0.4285,  ...,  0.5522, -0.0338,  0.3462]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.5767,  0.0768, -0.2260,  ...,  0.8335,  0.1581,  0.0602],
        [-0.0081,  0.2847, -0.3801,  ...,  0.5381, -0.0360, -0.2832],
        [ 0.6240, -0.3176,  0.0200,  ..., -0.0693, -0.1196,  0.1891],
        [ 0.4355,  0.0956, -0.2369,  ...,  0.9331, -0.1584,  0.1967]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.3479,  0.0467, -0.3140,  ...,  0.4817, -0.0709,  0.4727],
        [-0.1398,  0.1194, -0.2363,  ...,  0.7104,  0.1027, -0.0292],
        [ 0.1469,  0.3254, -0.1835,  ...,  0.6411, -0.0107, -0.0373],
        [ 0.0544, -0.0464, -0.2170,  ...,  0.7231, -0.1393,  0.1462]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.2053, -0.0322, -0.2627,  ...,  0.8442,  0.0903,  0.1897],
        [ 0.4011, -0.0842, -0.5767,  ...,  0.4736,  0.4285,  0.0362],
        [ 0.3069,  0.0536, -0.0857,  ...,  0.3250, -0.0044,  0.1080],
        [ 0.0803, -0.2683, -0.4873,  ...,  0.8223,  0.0133,  0.3040]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.2793,  0.1578, -0.2993,  ...,  0.9121,  0.1400,  0.2273],
        [ 0.2290, -0.0539, -0.4346,  ...,  0.8511,  0.2749,  0.2493],
        [-0.2399, -0.0668, -0.3096,  ...,  1.0771, -0.2849,  0.0999],
        [ 0.3850, -0.0575, -0.2659,  ...,  0.6626, -0.1198,  0.2010]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Loss after 1920 Batch is nan 
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBa

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 3.2129e-01,  5.8008e-01, -3.1470e-01,  ...,  5.6348e-01,
         -3.8037e-01,  1.4014e-01],
        [ 1.6150e-01,  3.4821e-02, -1.8958e-01,  ...,  6.1719e-01,
         -1.2103e-01,  1.4099e-01],
        [ 1.1835e-01, -2.2510e-01, -2.6099e-01,  ...,  8.5254e-01,
          5.1392e-02, -5.7983e-04],
        [ 1.8469e-01,  2.4792e-01, -4.0601e-01,  ...,  1.3359e+00,
         -2.3193e-02,  1.5002e-01]], device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [n

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.1240, -0.0349, -0.2961,  ...,  0.5229, -0.1226,  0.4258],
        [-0.0046,  0.2184, -0.3499,  ...,  0.9019,  0.0614, -0.0514],
        [ 0.1921,  0.2529, -0.1026,  ...,  0.7466,  0.2659,  0.1272],
        [ 0.0775,  0.2019, -0.3752,  ...,  0.8130, -0.1439,  0.0457]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.2327,  0.2737, -0.3035,  ...,  0.7822, -0.2561,  0.0514],
        [ 0.4150,  0.1960, -0.3879,  ...,  0.8804, -0.0526,  0.0668],
        [ 0.2286,  0.2244, -0.0734,  ...,  0.6455, -0.2959,  0.0107],
        [ 0.4224,  0.0400, -0.4077,  ...,  0.5811,  0.2194, -0.0458]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.0979, -0.0031, -0.4880,  ...,  0.7866, -0.0173,  0.3430],
        [ 0.4897,  0.3174, -0.1652,  ...,  0.5947, -0.0080, -0.0561],
        [ 0.2954,  0.0793, -0.2764,  ...,  0.1376, -0.1929,  0.2539],
        [ 0.1460, -0.2194, -0.6665,  ...,  0.8652,  0.1212,  0.4651]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 1.8372e-02,  3.9380e-01,  9.9915e-02,  ...,  5.4053e-01,
          1.3818e-01, -2.1997e-01],
        [ 5.9668e-01,  1.1432e-01, -4.6826e-01,  ...,  3.3130e-01,
          6.7759e-04,  3.1860e-01],
        [-3.4210e-02, -1.2695e-01, -3.0518e-01,  ...,  7.1924e-01,
         -1.1774e-01,  2.8101e-01],
        [ 4.9878e-01, -1.8640e-01, -6.0498e-01,  ...,  8.1982e-01,
          1.7871e-01,  4.2334e-01]], device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Loss after 1980 Batch is nan 
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  .

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.2820,  0.1179, -0.0824,  ...,  0.5132,  0.3379, -0.1256],
        [ 0.2018,  0.3123, -0.2209,  ...,  1.0801, -0.2070,  0.0525],
        [-0.0299,  0.3499,  0.0881,  ...,  0.3569, -0.4109,  0.1664],
        [ 0.4614,  0.4255, -0.0169,  ...,  0.4087,  0.3164,  0.6655]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.1935,  0.0204, -0.3523,  ...,  0.5190, -0.1360,  0.1997],
        [ 0.2610, -0.0917, -0.5156,  ...,  0.9160, -0.0401, -0.0273],
        [ 0.2505,  0.2510, -0.4133,  ...,  0.7983,  0.0613, -0.0046],
        [ 0.4377,  0.0201, -0.1069,  ...,  0.7676, -0.0928,  0.2249]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.1278,  0.5361, -0.4539,  ...,  0.8716,  0.2042,  0.0151],
        [ 0.2234,  0.0950, -0.2430,  ...,  0.6187, -0.2725,  0.0090],
        [ 0.0550,  0.1365, -0.6499,  ...,  0.4890,  0.2054, -0.3452],
        [ 0.1946,  0.2769, -0.3105,  ...,  0.5190, -0.0848,  0.1417]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.3999, -0.0403, -0.2212,  ...,  0.4607,  0.0908,  0.2073],
        [ 0.5054,  0.0029, -0.1050,  ...,  0.4619, -0.1254,  0.1906],
        [ 0.3337,  0.2651, -0.3042,  ...,  0.3887,  0.3604, -0.0771],
        [ 0.5269,  0.2382, -0.5952,  ...,  0.7227, -0.1641,  0.2153]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.4526,  0.0285, -0.2788,  ...,  1.0654,  0.0521,  0.3235],
        [ 0.0831,  0.1044, -0.1642,  ...,  0.9341, -0.2791, -0.0276],
        [ 0.1254,  0.2834, -0.5371,  ...,  0.5327,  0.0848, -0.1713],
        [ 0.2458,  0.2278, -0.3201,  ...,  1.0068, -0.0544,  0.1794]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Loss after 2040 Batch is nan 
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBa

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.4004,  0.1595, -0.1997,  ...,  0.9932,  0.0320,  0.1012],
        [ 0.1221,  0.0852, -0.6479,  ...,  0.4651, -0.1426,  0.2339],
        [ 0.2288,  0.1591, -0.3127,  ...,  0.5464, -0.1252, -0.0012],
        [ 0.3865, -0.0605, -0.3352,  ...,  0.8115, -0.0029,  0.0239]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.1593, -0.0499, -0.3882,  ...,  0.6187, -0.0502,  0.4585],
        [ 0.3176,  0.3713, -0.1500,  ...,  0.6816, -0.0326,  0.2137],
        [ 0.1758, -0.1096, -0.0015,  ...,  0.5820,  0.1097,  0.1410],
        [ 0.3591,  0.1040, -0.2340,  ...,  0.8691,  0.0152, -0.0327]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.2190,  0.0109, -0.4036,  ...,  0.9141,  0.0706,  0.0784],
        [ 0.5742,  0.2192, -0.3970,  ...,  0.6387, -0.0135,  0.0031],
        [-0.1631, -0.2164, -0.2795,  ...,  0.7153, -0.1569,  0.2522],
        [ 0.3469,  0.4729, -0.5000,  ...,  0.5386,  0.0237,  0.0762]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.2654, -0.1663, -0.6890,  ...,  0.7241,  0.1871,  0.3560],
        [ 0.2588, -0.1870, -0.4836,  ...,  0.4976, -0.2705,  0.1041],
        [-0.1263, -0.0565, -0.2375,  ...,  0.7607,  0.0660,  0.0201],
        [ 0.3357,  0.0052, -0.6387,  ...,  0.7173, -0.2067,  0.2439]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.3240,  0.3604, -0.1465,  ...,  0.7065, -0.2266, -0.1619],
        [ 0.3511,  0.1249, -0.2891,  ...,  0.7900, -0.0793,  0.1377],
        [ 0.2961,  0.2330, -0.3655,  ...,  0.8618, -0.0547,  0.1067],
        [ 0.0495,  0.3127, -0.4497,  ...,  0.7573, -0.0282,  0.2050]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Loss after 2100 Batch is nan 
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBa

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 2.7856e-01,  4.4067e-01, -2.8613e-01,  ...,  8.7158e-01,
         -3.6670e-01, -1.5491e-01],
        [ 1.8591e-01,  8.1444e-04,  6.8848e-02,  ...,  7.1240e-01,
         -2.1741e-01, -1.9177e-01],
        [-7.4234e-03,  1.9519e-01, -2.9834e-01,  ...,  8.2520e-01,
         -7.1533e-02, -1.1401e-01],
        [ 2.7563e-01,  5.4102e-01, -3.7646e-01,  ...,  4.6899e-01,
         -3.0151e-02,  1.9348e-01]], device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [n

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.2009,  0.0352, -0.0524,  ...,  1.3926,  0.0925,  0.2595],
        [ 0.4041,  0.2194, -0.5713,  ...,  0.8447,  0.0942,  0.1172],
        [ 0.2245,  0.1649, -0.3787,  ...,  0.6587, -0.2301,  0.0704],
        [ 0.2678,  0.1881, -0.4858,  ...,  1.0215,  0.1511,  0.2489]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.1490, -0.1558, -0.1959,  ...,  0.8721,  0.1353,  0.1357],
        [ 0.3083,  0.2194, -0.4521,  ...,  0.5151, -0.1982, -0.0766],
        [ 0.2937,  0.4407, -0.0655,  ...,  0.5513, -0.1138, -0.1556],
        [ 0.2822,  0.1677, -0.3486,  ...,  0.6611,  0.0289,  0.3406]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.4082,  0.2103, -0.2181,  ...,  0.6572,  0.1940, -0.0975],
        [ 0.1158,  0.2708, -0.2269,  ...,  1.0283, -0.1052, -0.3191],
        [ 0.2542, -0.1357, -0.1798,  ...,  0.5127, -0.1373,  0.2756],
        [ 0.5298,  0.5273, -0.1121,  ...,  0.2141,  0.2434,  0.4851]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.0750,  0.2322, -0.3804,  ...,  0.6250,  0.3669, -0.1617],
        [ 0.2595,  0.4507, -0.2106,  ...,  0.4368,  0.0714, -0.3896],
        [ 0.4290, -0.1093, -0.2600,  ...,  0.7974, -0.1187, -0.1620],
        [-0.1887,  0.1056, -0.3274,  ...,  0.7422, -0.1550,  0.3389]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Loss after 2160 Batch is nan 
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBa

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[-0.0537, -0.0397, -0.3462,  ...,  1.1855, -0.0427, -0.2480],
        [ 0.5864,  0.2214, -0.6494,  ...,  0.6631, -0.0733,  0.2808],
        [ 0.0698, -0.0967, -0.3450,  ...,  0.3760,  0.0296,  0.4624],
        [-0.0826, -0.0107, -0.1862,  ...,  0.5767, -0.2942,  0.0326]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[-0.2428, -0.0854, -0.1947,  ...,  0.7842, -0.2778, -0.0203],
        [ 0.3228, -0.0880, -0.5449,  ...,  0.7476, -0.1238,  0.3503],
        [ 0.0227,  0.0554, -0.3218,  ...,  1.0498, -0.1770,  0.0075],
        [ 0.4275,  0.1163, -0.1261,  ...,  0.8623, -0.2998,  0.1677]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.0076,  0.6411, -0.2174,  ...,  0.7573,  0.1116,  0.1238],
        [ 0.0511, -0.1884, -0.4177,  ...,  0.7183,  0.4978, -0.0222],
        [-0.3137, -0.1057, -0.4602,  ...,  0.5176, -0.0770, -0.1201],
        [ 0.1677,  0.1248, -0.1554,  ...,  0.2393, -0.1425,  0.0623]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.5615,  0.6729, -0.3989,  ...,  0.3545, -0.2478,  0.0619],
        [ 0.4182,  0.0777, -0.3606,  ...,  0.8628,  0.0416,  0.1779],
        [ 0.4387,  0.0791, -0.3943,  ...,  0.8560, -0.0334,  0.1472],
        [ 0.1132, -0.0173, -0.2905,  ...,  0.7041,  0.0321,  0.2588]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.3179,  0.3853, -0.4714,  ...,  0.9053,  0.0188,  0.0026],
        [-0.0889, -0.0768, -0.1952,  ...,  0.6260, -0.2644, -0.2524],
        [ 0.2776, -0.0540, -0.0713,  ...,  0.7251, -0.0165,  0.1512],
        [ 0.2285, -0.0786, -0.5542,  ...,  0.8716,  0.0648,  0.2646]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Loss after 2220 Batch is nan 
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBa

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.0400, -0.0793, -0.1377,  ...,  0.6470, -0.1886, -0.1318],
        [ 0.1254, -0.0264, -0.1344,  ...,  0.5488,  0.0461, -0.0064],
        [ 0.1792,  0.2404, -0.2085,  ...,  0.7002, -0.0775,  0.2098],
        [ 0.0500, -0.0205, -0.0964,  ...,  0.5508, -0.1644, -0.0753]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

outputs_teacher tensor([[-0.2345,  0.1686, -0.1151,  ...,  0.9844, -0.1487, -0.2612],
        [-0.1658, -0.1840, -0.3408,  ...,  0.8354,  0.1003,  0.2505],
        [ 0.4058,  0.0764, -0.1240,  ...,  0.6221,  0.0217, -0.1885],
        [ 0.1257,  0.1508, -0.2452,  ...,  0.5996,  0.1178,  0.0539]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 2.4146e-01,  2.0129e-01, -2.2595e-01,  ...,  7.9590e-01,
          9.2208e-05,  2.6489e-01],
        [ 1.6528e-01,  1.4880e-01, -4.0356e-01,  ...,  8.1836e-01,
         -2.8955e-01,  5.2148e-01],
        [ 4.3677e-01,  3.0493e-01, -3.7183e-01,  ...,  7.5586e-01,
          3.4

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.3796, -0.1136, -0.3108,  ...,  0.9111, -0.2341, -0.1772],
        [ 0.1559,  0.1567, -0.2135,  ...,  0.5576, -0.5244,  0.2788],
        [-0.1184, -0.0046, -0.3035,  ...,  1.0000,  0.0328,  0.2366],
        [ 0.3577, -0.2693, -0.1353,  ...,  0.5103, -0.2076,  0.0948]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[-0.1752,  0.2285, -0.4465,  ...,  0.9180,  0.0567, -0.2261],
        [ 0.0644,  0.1747, -0.2349,  ...,  0.7354,  0.2311,  0.1798],
        [-0.0275,  0.2507, -0.4226,  ...,  1.1514,  0.0952, -0.2673],
        [ 0.2196, -0.3103, -0.1937,  ...,  0.4006,  0.0907,  0.5479]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.0714, -0.1028, -0.1024,  ...,  0.3523,  0.1307,  0.3333],
        [-0.0047, -0.0354, -0.1058,  ...,  0.6509,  0.1805,  0.2374],
        [ 0.1327,  0.2742, -0.1255,  ...,  0.5659,  0.1458, -0.2010],
        [ 0.2188,  0.0433, -0.4451,  ...,  0.3875,  0.1562, -0.0734]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Loss after 2280 Batch is nan 
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBa

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.4607,  0.0130, -0.3923,  ...,  0.7290, -0.1714,  0.2522],
        [ 0.1365,  0.0461, -0.5220,  ...,  0.6982,  0.2947,  0.0445],
        [ 0.3435,  0.2181, -0.2327,  ...,  0.5991, -0.1487, -0.3628],
        [ 0.1869,  0.0741, -0.2228,  ...,  1.3086, -0.0646,  0.2737]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

outputs_teacher tensor([[ 0.1576,  0.2209,  0.0091,  ...,  0.6665,  0.0916,  0.5889],
        [ 0.5044,  0.0866, -0.2235,  ...,  0.6973,  0.0751,  0.2235],
        [ 0.1489, -0.2649, -0.2078,  ...,  0.7583, -0.2134,  0.1385],
        [ 0.5962, -0.2800, -0.4678,  ...,  0.6816,  0.3311,  0.2117]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.0380,  0.0296, -0.4829,  ...,  0.8589,  0.2607,  0.2600],
        [ 0.1072,  0.3835,  0.1746,  ...,  0.5444,  0.0184, -0.1222],
        [ 0.5171,  0.1735, -0.4365,  ...,  0.3096,  0.0044,  0.3867],
        [ 0.0281,  0.0600, -0.2344,  ...,  0.6045, -0.1311,  0.4810]],
     

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[-0.0406,  0.5264, -0.2484,  ...,  0.1261, -0.2791,  0.2747],
        [ 0.2849, -0.2539, -0.3049,  ...,  0.4106, -0.1522,  0.4275],
        [ 0.0928,  0.0668, -0.4609,  ...,  1.0039, -0.2698, -0.1207],
        [ 0.3369,  0.0656, -0.2773,  ...,  0.2629, -0.1938,  0.2247]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.1669,  0.0421, -0.3877,  ...,  0.4888, -0.1943, -0.0459],
        [ 0.2778,  0.0520, -0.2690,  ...,  0.7417, -0.3914, -0.1152],
        [ 0.3298, -0.0263, -0.3645,  ...,  0.3772,  0.1005,  0.2329],
        [ 0.2142,  0.2703, -0.0663,  ...,  0.5723, -0.1818, -0.0071]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.2961,  0.2137, -0.1664,  ...,  0.6123, -0.0189,  0.3369],
        [ 0.3193, -0.1147, -0.2349,  ...,  0.7891,  0.1526, -0.1886],
        [ 0.1663,  0.0970, -0.4473,  ...,  0.6094,  0.0305, -0.0630],
        [ 0.2214,  0.1005, -0.1909,  ...,  0.6875, -0.0765, -0.0884]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.4487,  0.1621, -0.3467,  ...,  0.6191, -0.1476,  0.2021],
        [-0.2457,  0.1447, -0.1412,  ...,  0.7808,  0.0403, -0.1553],
        [ 0.6411,  0.0764, -0.1259,  ...,  0.6392,  0.3254,  0.1262],
        [ 0.1984,  0.0372, -0.1726,  ...,  0.6807,  0.0035,  0.1713]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.6802,  0.1246, -0.3142,  ...,  0.3840, -0.0598,  0.4272],
        [ 0.2656, -0.0902, -0.1482,  ...,  1.0283, -0.2407, -0.0085],
        [-0.0144, -0.0222, -0.2810,  ...,  0.8359, -0.0509, -0.0788],
        [ 0.0280,  0.2764, -0.4758,  ...,  0.9697, -0.2437,  0.1716]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Loss after 2370 Batch is nan 
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBa

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.2400, -0.1388, -0.0549,  ...,  0.4258, -0.3472,  0.0778],
        [ 0.4094,  0.2195, -0.4580,  ...,  0.7612, -0.1813,  0.0572],
        [ 0.2483,  0.0604, -0.0732,  ...,  0.8682,  0.1646, -0.1009],
        [ 0.2847,  0.3870, -0.0863,  ...,  0.8911, -0.3220,  0.0171]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.1096,  0.4016, -0.0401,  ...,  0.4060, -0.1498, -0.1420],
        [ 0.1316,  0.1022, -0.5034,  ...,  0.8120, -0.1583, -0.0565],
        [ 0.4832,  0.3198, -0.3125,  ...,  0.3660,  0.3352,  0.0022],
        [ 0.3677,  0.0861, -0.4441,  ...,  0.5996, -0.1968,  0.2345]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.4119, -0.1377, -0.1810,  ...,  0.6533,  0.1235,  0.0718],
        [ 0.0278,  0.1532, -0.2686,  ...,  0.6411, -0.1453,  0.1091],
        [ 0.3977,  0.0343, -0.0609,  ...,  0.2583, -0.0290,  0.2493],
        [-0.0077,  0.1395, -0.3059,  ...,  0.9673, -0.0816,  0.1469]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.1630,  0.0485, -0.1095,  ...,  0.8374, -0.1284,  0.0978],
        [ 0.0525,  0.1142, -0.2440,  ...,  1.0752,  0.1056,  0.2646],
        [ 0.3569,  0.1203, -0.2216,  ...,  0.6538, -0.0086,  0.2251],
        [ 0.0507, -0.1438, -0.0653,  ...,  0.7710, -0.0967,  0.1560]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

outputs_teacher tensor([[ 0.0276,  0.4253, -0.7192,  ...,  1.0098,  0.3508,  0.2595],
        [ 0.8789,  0.3147, -0.3030,  ...,  0.6001,  0.1433, -0.0401],
        [ 0.3955,  0.1521, -0.3501,  ...,  0.7935, -0.0419, -0.2529],
        [ 0.4475, -0.1915, -0.0913,  ...,  0.8066, -0.0540,  0.1835]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Loss after 2430 Batch is nan 
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[-0.2407,  0.0268,  0.0015,  ...,  0.9976, -0.1126, -0.2891],
        [ 0.1031,  0.0309, -0.1691,  ...,  0.6694, -0.4016,  0.4600],
        [ 0.2098,  0.3782,  0.3433,  ...,  0.3777, -0.0592, -0.2058],
        [ 0.0036,  0.0037, -0.0718,  ...,  0.4

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.1036,  0.0673, -0.2452,  ...,  0.3611, -0.0175,  0.1974],
        [ 0.1022,  0.3799, -0.3125,  ...,  0.6719, -0.0674, -0.0617],
        [ 0.4026,  0.2419, -0.3027,  ...,  0.8535,  0.1309,  0.1981],
        [ 0.1243,  0.2529, -0.1730,  ...,  0.6416,  0.0968, -0.1181]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.4460,  0.5903, -0.3645,  ...,  0.4736,  0.1527, -0.0895],
        [ 0.1129,  0.2258, -0.0779,  ...,  0.4519,  0.0056, -0.0664],
        [ 0.0747, -0.0705, -0.3008,  ...,  0.7500, -0.2837,  0.1816],
        [ 0.2546,  0.1429, -0.2269,  ...,  0.7729, -0.3716,  0.2083]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.1006,  0.3667, -0.8276,  ...,  0.7441,  0.2812,  0.0185],
        [ 0.2507,  0.3618, -0.3564,  ...,  0.8945, -0.0755,  0.0340],
        [ 0.3948, -0.0869, -0.3157,  ...,  0.3396, -0.1410,  0.0827],
        [ 0.2030,  0.3896, -0.3093,  ...,  0.3650, -0.0797,  0.2864]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.2061, -0.1772, -0.3621,  ...,  0.7100,  0.0142,  0.2747],
        [ 0.4612,  0.1166, -0.3672,  ...,  0.4631, -0.2354, -0.0206],
        [ 0.0196,  0.1392, -0.1931,  ...,  0.4705, -0.0044,  0.2406],
        [-0.0452,  0.3330, -0.2883,  ...,  0.4900,  0.0931,  0.0747]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.2360,  0.3127, -0.4390,  ...,  0.8286,  0.1200, -0.1847],
        [ 0.2659, -0.0333, -0.4666,  ...,  0.6558,  0.2219,  0.3857],
        [ 0.3931, -0.2263, -0.0641,  ...,  0.4346, -0.1819,  0.2115],
        [ 0.2996, -0.0352, -0.3638,  ...,  0.4761,  0.0534, -0.0807]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Loss after 2490 Batch is nan 
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
out

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.0286,  0.1799, -0.0293,  ...,  1.0078, -0.0194, -0.0747],
        [ 0.1471,  0.3638, -0.1089,  ...,  0.5972, -0.1481, -0.1541],
        [-0.0545,  0.0714, -0.3953,  ...,  0.7393,  0.0036,  0.1248],
        [-0.2244, -0.2681, -0.2201,  ...,  0.7095, -0.0343,  0.1063]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.4309, -0.2612, -0.4543,  ...,  0.5269, -0.1851,  0.3560],
        [-0.0599,  0.0132, -0.4241,  ...,  0.6543, -0.2556,  0.1005],
        [ 0.4304,  0.0046, -0.5342,  ...,  0.5645, -0.2517,  0.3301],
        [-0.0595, -0.1129, -0.2343,  ...,  0.7935, -0.2532,  0.1990]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.3757,  0.1964, -0.4880,  ...,  0.7944,  0.2255,  0.0613],
        [ 0.1079, -0.0873, -0.3667,  ...,  0.4897,  0.1472,  0.3735],
        [ 0.2732,  0.4358, -0.3452,  ...,  0.9395, -0.1031,  0.0193],
        [ 0.3354,  0.1738, -0.3628,  ...,  1.1270, -0.0408,  0.2078]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[-0.0590,  0.0661, -0.6455,  ...,  0.7568, -0.2266, -0.0069],
        [ 0.5205,  0.0056, -0.3945,  ...,  0.6270, -0.2991,  0.1290],
        [ 0.2825,  0.0169, -0.2148,  ...,  0.4995, -0.0657, -0.1092],
        [-0.1770, -0.1967, -0.4104,  ...,  0.7983,  0.2220,  0.2681]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Loss after 2550 Batch is nan 
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.3745,  0.0777, -0.4048,  ...,  0.2266, -0.0257,  0.2781],
        [ 0.4236,  0.3108, -0.4495,  ...,  0.6719, -0.1422, -0.2383],
        [ 0.0748,  0.1558, -0.1505,  ...,  0.3496, -0.0745,  0.2786],
        [ 0.3267, -0.0786, -0.2659,  ...,  0.7568, -0.1973,  0.2384]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., n

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.3984,  0.1292, -0.5264,  ...,  0.6626, -0.1102,  0.0016],
        [-0.2551, -0.3069, -0.3542,  ...,  0.9341, -0.2206,  0.0799],
        [ 0.7383,  0.4866, -0.3428,  ...,  0.3801,  0.0561,  0.4753],
        [ 0.3281,  0.3916, -0.4624,  ...,  0.4385, -0.0524,  0.0142]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

outputs_teacher tensor([[ 0.3423, -0.1213, -0.2805,  ...,  0.5254,  0.0244, -0.1779],
        [ 0.4819,  0.6479, -0.0838,  ...,  0.4895,  0.0800,  0.2732],
        [ 0.4448,  0.0109, -0.3235,  ...,  0.6904, -0.1786,  0.1486],
        [-0.0979,  0.1183, -0.1713,  ...,  0.7119,  0.1315, -0.0024]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.2634,  0.2878, -0.3984,  ...,  0.5962, -0.3052,  0.1232],
        [-0.1039,  0.1196, -0.1213,  ...,  0.7915, -0.0646,  0.0081],
        [ 0.0919,  0.3308, -0.2546,  ...,  1.0518, -0.2008,  0.4351],
        [ 0.5679, -0.1066, -0.6006,  ...,  0.5547, -0.3379,  0.2581]],
     

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[-0.0328,  0.3220, -0.3303,  ...,  1.0117, -0.2023,  0.0837],
        [ 0.2054,  0.4094, -0.2803,  ...,  0.6196, -0.0565,  0.0522],
        [-0.1411, -0.1316, -0.0267,  ...,  0.7661, -0.2307, -0.1241],
        [ 0.4216,  0.1940, -0.1125,  ...,  0.5161,  0.2603,  0.0745]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.2566,  0.3264, -0.1282,  ...,  0.9341, -0.0957,  0.0122],
        [ 0.1252,  0.2661,  0.0482,  ...,  0.5093, -0.0959,  0.0862],
        [ 0.2483,  0.2445, -0.3613,  ...,  0.5352, -0.0957, -0.0818],
        [-0.1693, -0.0410, -0.1769,  ...,  0.5298, -0.1237, -0.0111]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[-0.0061,  0.1570, -0.3506,  ...,  0.6104, -0.2781,  0.1087],
        [ 0.2771, -0.2402, -0.3818,  ...,  0.5400, -0.0213,  0.3372],
        [ 0.3040, -0.5503, -0.3647,  ...,  0.3845,  0.2134, -0.1385],
        [-0.1932, -0.1562, -0.1602,  ...,  0.6768, -0.2512,  0.1384]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.4226,

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.0717,  0.2101, -0.2527,  ...,  0.7573, -0.0420,  0.2112],
        [ 0.3047,  0.4412, -0.4375,  ...,  0.7300, -0.0930, -0.1608],
        [ 0.3098, -0.1870, -0.3684,  ...,  0.7490,  0.0209,  0.1565],
        [ 0.3252,  0.4841, -0.3567,  ...,  0.8296, -0.2168,  0.0153]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.1818,  0.1063, -0.3914,  ...,  1.0127,  0.0372, -0.1936],
        [ 0.1219, -0.0651, -0.4829,  ...,  0.6860, -0.3630, -0.0544],
        [-0.0736,  0.0794, -0.1598,  ...,  0.8296, -0.1003,  0.0273],
        [ 0.0298,  0.1613, -0.3669,  ...,  0.6738,  0.0919,  0.0238]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.3110,  0.0658, -0.3110,  ...,  0.6216, -0.1538,  0.0976],
        [-0.0930, -0.1921, -0.3853,  ...,  1.2158, -0.0812,  0.2054],
        [-0.2161, -0.0307,  0.1024,  ...,  0.5840,  0.0027, -0.0844],
        [-0.2617, -0.0663,  0.0013,  ...,  0.7485, -0.3530, -0.0641]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.0508,  0.0541, -0.4683,  ...,  1.0244, -0.0573, -0.0605],
        [ 0.0598, -0.1770, -0.3445,  ...,  1.0635,  0.2483,  0.1309],
        [ 0.1548,  0.1494, -0.2288,  ...,  0.9932,  0.1960,  0.2346],
        [ 0.3672, -0.3782, -0.2947,  ...,  0.7251,  0.1967,  0.3108]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.3281,  0.0595, -0.4280,  ...,  0.4951,  0.1991,  0.1631],
        [-0.0933, -0.0344, -0.4529,  ...,  0.5220,  0.1870,  0.1364],
        [ 0.3691,  0.3828, -0.0969,  ...,  0.7559,  0.1636,  0.1577],
        [-0.1923,  0.3220,  0.0615,  ...,  0.6616, -0.1179, -0.0431]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.2646,

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[-0.2484, -0.1307, -0.2888,  ...,  0.6323,  0.1946, -0.4292],
        [ 0.3770, -0.0048, -0.1027,  ...,  0.6377,  0.2469,  0.4102],
        [ 0.2015,  0.1816, -0.3818,  ...,  0.9287, -0.0436,  0.1881],
        [ 0.1056,  0.2238, -0.3652,  ...,  0.4585, -0.2747,  0.1888]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.1338, -0.0377, -0.2151,  ...,  0.6562,  0.0530, -0.0959],
        [ 0.1754,  0.1656, -0.4746,  ...,  1.1436,  0.2854,  0.2452],
        [ 0.4685,  0.1461, -0.1119,  ...,  0.4778,  0.0134, -0.1486],
        [ 0.1841,  0.0859, -0.3171,  ...,  0.4753, -0.1285,  0.1731]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.3645,  0.3337, -0.2280,  ...,  1.1465, -0.3069, -0.0770],
        [ 0.5986,  0.2229, -0.4038,  ...,  0.8384,  0.0218, -0.0447],
        [ 0.5127,  0.0752, -0.3430,  ...,  0.6089,  0.0676, -0.0905],
        [ 0.4866,  0.2703, -0.3752,  ...,  0.8862,  0.0085,  0.1550]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.3079,  0.0626, -0.5898,  ...,  0.6821,  0.1792,  0.2113],
        [ 0.1721,  0.2275, -0.3586,  ...,  0.7139, -0.1887,  0.1183],
        [ 0.3508,  0.2798, -0.2076,  ...,  0.8687,  0.1198,  0.1097],
        [ 0.4187, -0.3086, -0.1964,  ...,  0.6558, -0.1914,  0.2324]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.4473,

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 3.7109e-01,  7.5439e-02, -4.3506e-01,  ...,  5.1416e-01,
         -6.7505e-02,  7.5195e-02],
        [ 1.3135e-01,  4.9713e-02, -2.3120e-01,  ...,  9.1992e-01,
         -2.5879e-01,  2.4023e-01],
        [-1.1301e-04,  1.0583e-01, -2.3315e-01,  ...,  9.2041e-01,
         -2.8369e-01, -2.8046e-02],
        [ 3.7183e-01,  4.1821e-01, -1.4734e-01,  ...,  5.2441e-01,
          1.0498e-01, -2.1594e-01]], device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [n

outputs_teacher tensor([[ 1.3330e-01,  2.4414e-01, -1.6443e-01,  ...,  4.7827e-01,
          6.9275e-02, -5.4199e-02],
        [ 2.8634e-04,  1.9421e-01, -3.1616e-01,  ...,  1.1016e+00,
         -6.4880e-02,  2.5732e-01],
        [ 5.7959e-01,  4.2700e-01, -2.1985e-01,  ...,  5.8301e-01,
          1.2927e-01,  4.2938e-02],
        [ 4.6118e-01,  1.6724e-01, -1.9592e-01,  ...,  6.9824e-01,
         -2.5986e-02,  2.8711e-01]], device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.4543,  0.3545,  0.0426,  ...,  0.3894, -0.1921,  0.2382],
        [-0.2236,  0.0944, -0.3511,  ...,  0.6953,  0.3450,  0.1354],
        [ 0.1479,  0.

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.3269,  0.2483, -0.1973,  ...,  0.7661, -0.2783,  0.2874],
        [ 0.1035, -0.0834, -0.5566,  ...,  0.9414, -0.1301, -0.0143],
        [ 0.4268, -0.0619, -0.2822,  ...,  0.7041, -0.0193,  0.1490],
        [ 0.4399,  0.3474, -0.1287,  ...,  0.7295, -0.1775, -0.0229]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.1749,  0.3706, -0.0756,  ...,  0.7583, -0.1127,  0.0745],
        [ 0.4131,  0.3862, -0.3118,  ...,  0.8892, -0.2578, -0.0299],
        [ 0.0357, -0.1362, -0.2209,  ...,  0.5244,  0.1365,  0.1926],
        [ 0.0327, -0.1093, -0.1364,  ...,  0.5376,  0.0737,  0.3291]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.3567, -0.2003, -0.2418,  ...,  0.6914,  0.0039,  0.6265],
        [ 0.1716,  0.1213, -0.3855,  ...,  1.0127,  0.0935,  0.3440],
        [ 0.1661,  0.2094, -0.1265,  ...,  0.8271,  0.0351, -0.0081],
        [ 0.1227,  0.1835, -0.0175,  ...,  0.7485,  0.0443, -0.1821]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.4890,  0.1240, -0.0992,  ...,  0.8496, -0.3369,  0.2361],
        [ 0.2744, -0.0642, -0.2693,  ...,  0.6538,  0.0704,  0.1335],
        [-0.0804, -0.3030, -0.2312,  ...,  0.5117, -0.1213,  0.3550],
        [ 0.2325,  0.2512, -0.1345,  ...,  0.6553,  0.0953,  0.1089]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.3557,  0.2656, -0.1594,  ...,  0.4290, -0.1997,  0.0650],
        [ 0.2233,  0.1394, -0.4954,  ...,  0.7002, -0.0505,  0.0110],
        [-0.1062,  0.2140, -0.1154,  ...,  0.7446, -0.1888, -0.1111],
        [ 0.2576, -0.0426, -0.2305,  ...,  0.4058,  0.1560,  0.1371]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.0238,  0.1017, -0.3677,  ...,  0.5952, -0.0816,  0.1810],
        [ 0.2883,  0.5322, -0.4780,  ...,  0.9058, -0.0598, -0.0054],
        [-0.0362,  0.1141, -0.2408,  ...,  0.6499, -0.0523,  0.5054],
        [ 0.0051, -0.0640, -0.2202,  ...,  0.6406, -0.2108,  0.2805]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.0230,  0.5830,  0.0177,  ...,  0.4673,  0.0630, -0.1082],
        [ 0.4197,  0.1577, -0.2078,  ...,  0.7500, -0.0907,  0.1483],
        [ 0.4958,  0.2236, -0.2463,  ...,  0.4041, -0.0599, -0.2798],
        [ 0.0146,  0.3345, -0.1631,  ...,  0.5732, -0.3108, -0.3096]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.2871, -0.2144, -0.1249,  ...,  0.5337,  0.0479,  0.0471],
        [ 0.3652,  0.4575, -0.3174,  ...,  0.6870, -0.3245,  0.3545],
        [ 0.3770,  0.2074, -0.3774,  ...,  0.7456,  0.1053,  0.4014],
        [ 0.6294, -0.0776, -0.3025,  ...,  0.3164, -0.2246,  0.0575]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.2305,  0.3394, -0.0705,  ...,  0.9082,  0.0164,  0.0193],
        [ 0.4041,  0.4365, -0.3025,  ...,  0.3521,  0.2037, -0.2267],
        [ 0.3167, -0.1617, -0.1165,  ...,  0.5410,  0.4724,  0.1245],
        [ 0.4504,  0.1350, -0.3950,  ...,  0.7476, -0.0278, -0.0367]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[-0.1373, -0.1694, -0.2720,  ...,  0.9551, -0.2559,  0.1721],
        [-0.0779,  0.3186, -0.1628,  ...,  0.4229, -0.1208,  0.4373],
        [ 0.3113, -0.1542, -0.0874,  ...,  0.5146,  0.1399, -0.1589],
        [ 0.1082, -0.0229, -0.3655,  ...,  0.9727, -0.0812,  0.3049]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.1862,  0.2450, -0.3862,  ...,  0.7788,  0.1453, -0.0319],
        [ 0.3962,  0.1377, -0.1732,  ...,  0.3882,  0.1880,  0.4619],
        [ 0.5181, -0.2734, -0.6226,  ...,  0.6851,  0.2458, -0.2161],
        [ 0.2627,  0.0547, -0.3552,  ...,  0.8037, -0.0208,  0.5566]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[-0.2465, -0.0481, -0.1962,  ...,  0.6440, -0.1266,  0.1982],
        [-0.1255, -0.1373, -0.3201,  ...,  0.5825, -0.1698, -0.0538],
        [ 0.2169,  0.2556, -0.2866,  ...,  0.9507, -0.1344,  0.0637],
        [ 0.1429, -0.0789, -0.3582,  ...,  0.9517, -0.1049, -0.0435]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

outputs_teacher tensor([[ 0.1890,  0.0836, -0.4971,  ...,  0.5234,  0.2310,  0.2896],
        [ 0.1500,  0.1105, -0.4680,  ...,  0.7607,  0.2935,  0.1694],
        [ 0.1870,  0.0692, -0.2339,  ...,  0.6465, -0.0105,  0.2861],
        [ 0.2673,  0.0017, -0.2583,  ...,  0.4778,  0.0337,  0.1776]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.2179,  0.1095, -0.4194,  ...,  1.1035, -0.3101,  0.0547],
        [ 0.1667,  0.1597, -0.2546,  ...,  0.7700,  0.1752, -0.0302],
        [-0.0940,  0.2576, -0.5459,  ...,  0.8257, -0.4380, -0.1043],
        [ 0.2820,  0.0359, -0.3896,  ...,  0.5171, -0.0136, -0.0554]],
     

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[-0.1131,  0.0751, -0.2742,  ...,  0.5674, -0.0476,  0.3071],
        [-0.0524,  0.1091, -0.4805,  ...,  0.3457, -0.1140,  0.0528],
        [-0.0270, -0.2573, -0.2032,  ...,  1.0273,  0.1075,  0.0975],
        [ 0.2881, -0.0136, -0.4631,  ...,  0.7202,  0.3909,  0.2954]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.1819, -0.0352, -0.2191,  ...,  0.2881,  0.2023,  0.1289],
        [ 0.3940,  0.2568, -0.3359,  ...,  0.7612,  0.3103, -0.0672],
        [ 0.1893,  0.3218, -0.2346,  ...,  0.6685,  0.2715,  0.3562],
        [ 0.2534,  0.0635, -0.4258,  ...,  0.6938,  0.2343,  0.1178]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Loss after 2930 Batch is nan 
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBa

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.2847,  0.3369, -0.3608,  ...,  0.9277,  0.0073,  0.1639],
        [ 0.2240,  0.0723, -0.5225,  ...,  0.9116,  0.3958,  0.3704],
        [ 0.1586, -0.3196, -0.2491,  ...,  0.7090,  0.0398,  0.1854],
        [ 0.3655, -0.1008, -0.3542,  ...,  0.6260,  0.0975,  0.1504]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 8.8989e-02,  1.2134e-01, -2.1741e-01,  ...,  9.4922e-01,
         -9.1248e-02,  6.7676e-01],
        [ 4.8511e-01,  4.5586e-04, -3.6377e-01,  ...,  9.1943e-01,
          1.0870e-01,  6.6895e-02],
        [ 1.1475e-01,  6.8054e-02, -2.6855e-01,  ...,  5.8643e-01,
         -8.0566e-02, -1.9165e-01],
        [ 2.4292e-01,  1.0455e-01, -5.1807e-01,  ...,  8.3789e-01,
          3.5059e-01,  3.3838e-01]], device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [n

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.1761,  0.1648, -0.3447,  ...,  0.7529, -0.0076,  0.3416],
        [ 0.4453, -0.3040, -0.6733,  ...,  0.6191, -0.1470,  0.2981],
        [-0.0373,  0.2600, -0.3008,  ...,  0.7983,  0.1304,  0.1503],
        [-0.0077,  0.2566, -0.1049,  ...,  1.0625, -0.0824,  0.4395]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.3464,  0.1630, -0.0236,  ...,  0.8604,  0.0681, -0.5176],
        [ 0.3296,  0.2406, -0.1837,  ...,  0.6187, -0.1412,  0.2671],
        [ 0.1354,  0.6626, -0.2720,  ...,  0.8628,  0.2220,  0.0061],
        [ 0.6328,  0.2615, -0.3550,  ...,  0.7383, -0.1261,  0.1600]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[-0.0459,  0.3508, -0.4089,  ...,  0.6401,  0.3250,  0.1974],
        [ 0.3376, -0.0967, -0.1907,  ...,  0.6528, -0.0435, -0.0213],
        [-0.1768,  0.1265, -0.2551,  ...,  0.8589,  0.0341, -0.2158],
        [ 0.1879,  0.2111, -0.4119,  ...,  0.7773, -0.0244,  0.1998]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Loss after 2990 Batch is nan 
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBa

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.3606,  0.0497, -0.3875,  ...,  0.2407, -0.2109,  0.2496],
        [ 0.0623, -0.1194, -0.4153,  ...,  0.9180,  0.1372, -0.1779],
        [-0.2876,  0.1919, -0.4468,  ...,  0.7173, -0.0371,  0.3176],
        [ 0.2561, -0.3577, -0.2852,  ...,  0.2566, -0.0693, -0.0393]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.2347, -0.0299, -0.3123,  ...,  0.8867, -0.1123,  0.2844],
        [ 0.3428,  0.4382, -0.3860,  ...,  0.7598, -0.2040,  0.1175],
        [ 0.1996, -0.1720, -0.4270,  ...,  1.0127, -0.0169,  0.2426],
        [ 0.3867, -0.0621, -0.2708,  ...,  0.5728,  0.0831, -0.0322]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[-0.4417, -0.0287, -0.2242,  ...,  0.7002, -0.1324, -0.2610],
        [ 0.3069, -0.2284, -0.2458,  ...,  0.3999,  0.2708,  0.0382],
        [ 0.5098, -0.4302, -0.3457,  ...,  0.6099,  0.1444,  0.3538],
        [ 0.2778, -0.2294, -0.3713,  ...,  0.6372,  0.1062,  0.1754]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.0614, -0.0817, -0.5474,  ...,  0.5430,  0.1028, -0.1262],
        [ 0.3638,  0.1731, -0.4053,  ...,  0.5532,  0.4653,  0.0883],
        [ 0.4126, -0.0236, -0.6162,  ...,  0.6025,  0.1580,  0.2698],
        [ 0.3442,  0.4675, -0.6626,  ...,  0.3589, -0.0974,  0.2101]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.0754,  0.3176,  0.0313,  ...,  0.4507, -0.1448,  0.1771],
        [ 0.1832,  0.0401, -0.1234,  ...,  0.6235,  0.0252,  0.6152],
        [ 0.0376,  0.2009, -0.1758,  ...,  0.5938, -0.0065,  0.0514],
        [ 0.1055,  0.2920, -0.4072,  ...,  0.8369, -0.0545,  0.1799]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Loss after 3050 Batch is nan 
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBa

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 9.7229e-02,  7.8003e-02, -1.8408e-01,  ...,  8.6133e-01,
         -2.9736e-01,  3.4814e-01],
        [ 1.6174e-01, -1.2158e-01, -1.3696e-01,  ...,  5.0537e-01,
         -5.6505e-04,  2.9565e-01],
        [ 2.8760e-01,  3.0469e-01,  3.8818e-02,  ...,  6.3574e-01,
         -3.4326e-01,  1.3748e-02],
        [ 2.1716e-01,  1.5857e-01, -2.2449e-01,  ...,  5.5518e-01,
         -1.7175e-01, -1.8579e-01]], device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [n

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.2593,  0.0858, -0.2191,  ...,  0.9482, -0.1104,  0.2087],
        [-0.1059,  0.3259, -0.3892,  ...,  0.6050, -0.0209, -0.0988],
        [ 0.4434,  0.3154, -0.5327,  ...,  1.0176, -0.0700, -0.0172],
        [ 0.2430, -0.2173, -0.0609,  ...,  0.6528, -0.0352,  0.3450]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.2192, -0.3975, -0.1626,  ...,  0.6567, -0.0095,  0.2671],
        [-0.0011,  0.0417, -0.2258,  ...,  0.7803, -0.2371, -0.1554],
        [ 0.1494,  0.1155, -0.3665,  ...,  0.4175, -0.3540,  0.3062],
        [ 0.5815,  0.6484,  0.0224,  ...,  0.6694,  0.1262, -0.1401]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.2717,  0.2299, -0.5884,  ...,  0.7095,  0.1077,  0.0079],
        [ 0.2290,  0.3096, -0.2126,  ...,  0.3870, -0.0014, -0.3760],
        [ 0.1486,  0.1141, -0.4551,  ...,  0.5630, -0.3853,  0.2737],
        [ 0.4119,  0.1522, -0.2698,  ...,  0.9556,  0.0580,  0.3315]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.4490,  0.0541, -0.4028,  ...,  0.6841,  0.1805,  0.1385],
        [ 0.3794,  0.0573, -0.2883,  ...,  0.8560,  0.1763,  0.1979],
        [ 0.0241, -0.0092, -0.0333,  ...,  0.4961, -0.1609, -0.0659],
        [ 0.3975,  0.0218, -0.3682,  ...,  0.4417,  0.0582,  0.0519]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[-0.0731,  0.0555, -0.3130,  ...,  0.8110,  0.0399,  0.1337],
        [ 0.1689,  0.4026, -0.4099,  ...,  0.8789, -0.2119, -0.1619],
        [ 0.3682, -0.0533, -0.2954,  ...,  0.6919, -0.0988,  0.0115],
        [ 0.1394,  0.0490, -0.1388,  ...,  0.8877,  0.0481,  0.0170]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.5093,  0.2219, -0.2844,  ...,  0.7632,  0.2321,  0.1304],
        [-0.0673, -0.0626, -0.1566,  ...,  1.0088,  0.2598, -0.0448],
        [ 0.1309,  0.1909, -0.4409,  ...,  0.7676, -0.1737, -0.0690],
        [ 0.2041,  0.1313, -0.3550,  ...,  0.4500, -0.2676,  0.1842]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.3618,  0.3723, -0.1959,  ...,  0.7456,  0.0879,  0.1012],
        [ 0.2330, -0.1558, -0.6431,  ...,  0.7168, -0.4058, -0.0848],
        [ 0.3479,  0.1098, -0.4160,  ...,  0.2744,  0.1165, -0.0620],
        [-0.0591, -0.0400, -0.5957,  ...,  0.6592, -0.1448,  0.2037]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.0833,  0.1649, -0.5244,  ...,  0.6948, -0.0285,  0.1661],
        [ 0.4292, -0.0659, -0.6084,  ...,  0.4795, -0.0579,  0.2467],
        [ 0.1523, -0.2294, -0.3379,  ...,  0.6401, -0.1292,  0.1127],
        [ 0.3274, -0.0986, -0.2568,  ...,  0.8364, -0.0042,  0.2693]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.1884, -0.1442, -0.3357,  ...,  0.3901,  0.0673,  0.0498],
        [ 0.0720, -0.0668, -0.1364,  ...,  0.9448, -0.3062, -0.0231],
        [ 0.0809,  0.1392, -0.5405,  ...,  0.3894,  0.0110, -0.1113],
        [ 0.1656,  0.0746, -0.3289,  ...,  1.1094,  0.0357, -0.2242]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.1493, -0.2788, -0.5005,  ...,  0.4958, -0.0540, -0.2053],
        [ 0.2832,  0.0751, -0.2817,  ...,  0.6987, -0.0547,  0.0921],
        [ 0.0143,  0.1543, -0.3577,  ...,  0.5215, -0.0141, -0.2081],
        [ 0.1392, -0.2314, -0.1305,  ...,  0.7227, -0.0971, -0.3328]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[-0.0325, -0.0770, -0.0228,  ...,  0.4846, -0.0936, -0.1934],
        [ 0.0550,  0.4360, -0.1473,  ...,  1.0020,  0.1587, -0.1098],
        [ 0.0427, -0.0894, -0.2249,  ...,  0.7661, -0.1715, -0.2448],
        [ 0.3486,  0.3193, -0.3464,  ...,  0.8765,  0.1571,  0.1199]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 1.7053e-01,  2.8540e-01, -1.3647e-01,  ...,  7.2168e-01,
         -2.2049e-02, -1.9092e-01],
        [ 8.7341e-02,  4.3384e-01, -1.9641e-01,  ...,  1.0254e+00,
          2.8687e-01,  1.7102e-01],
        [ 9.8389e-02,  5.4321e-02, -3.8257e-01,  ...,  5.7422e-01,
         -6.7139e-04,  2.7905e-01],
        [ 3.9575e-01,  1.7773e-01, -2.8833e-01,  ...,  4.8218e-01,
         -1.2122e-01, -2.2354e-02]], device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [n

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.4272,  0.0319, -0.5737,  ...,  0.7114, -0.3047,  0.0510],
        [ 0.2908, -0.2235, -0.3894,  ...,  0.8857,  0.1176,  0.0627],
        [ 0.0901,  0.4436, -0.4214,  ...,  0.7900, -0.0995, -0.0593],
        [ 0.3508,  0.4189, -0.2225,  ...,  0.5479, -0.0337,  0.1505]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.1643,  0.1159, -0.4478,  ...,  1.1143,  0.0369,  0.0334],
        [-0.0906,  0.1053, -0.3867,  ...,  0.6206, -0.0971, -0.0140],
        [ 0.1422, -0.0441, -0.3821,  ...,  0.5996, -0.0746,  0.0643],
        [-0.1503, -0.0621, -0.5049,  ...,  0.4583,  0.2532,  0.1042]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.1399,  0.0825, -0.4512,  ...,  0.6719,  0.0484,  0.3403],
        [ 0.3120,  0.1259, -0.3606,  ...,  0.9243,  0.1681, -0.1598],
        [ 0.1349,  0.3599, -0.5454,  ...,  0.7476, -0.0547,  0.0514],
        [ 0.2252,  0.4692, -0.2900,  ...,  0.8823,  0.1198,  0.3223]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.0643, -0.1895, -0.1969,  ...,  0.8447,  0.0064, -0.0644],
        [ 0.0605,  0.1862, -0.3459,  ...,  0.9507,  0.0036,  0.4436],
        [ 0.1978,  0.0903, -0.6909,  ...,  0.8340, -0.1891,  0.2429],
        [ 0.4595, -0.3203, -0.2406,  ...,  0.4597,  0.2346,  0.2308]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.3665,  0.2115, -0.1836,  ...,  0.5771, -0.0731, -0.0828],
        [ 0.1377,  0.2124, -0.3521,  ...,  0.5171, -0.1116,  0.0467],
        [ 0.1213,  0.4272,  0.0403,  ...,  0.7075,  0.1843,  0.0515],
        [-0.1827,  0.3408, -0.1979,  ...,  1.0400, -0.2430,  0.1523]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.3943, -0.2808, -0.2583,  ...,  0.8931,  0.3035, -0.1653],
        [-0.2389,  0.3198, -0.0787,  ...,  0.1335,  0.1720,  0.2097],
        [ 0.2181,  0.5171, -0.4832,  ...,  0.3496, -0.1316, -0.0455],
        [ 0.1857, -0.0721, -0.1484,  ...,  0.2554,  0.1981, -0.0261]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.4006,  0.5596, -0.1478,  ...,  0.4758,  0.0847,  0.2097],
        [ 0.0209, -0.1461, -0.0193,  ...,  0.7710, -0.0613, -0.0374],
        [ 0.3672,  0.2228, -0.2932,  ...,  0.1790,  0.0866, -0.0505],
        [ 0.1586,  0.1105, -0.3542,  ...,  0.5464, -0.2756,  0.1141]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.2026,  0.0854, -0.4673,  ...,  0.6226, -0.0468,  0.3003],
        [ 0.2451,  0.1710, -0.3145,  ...,  0.7144,  0.0795,  0.4277],
        [ 0.0750, -0.2357, -0.2866,  ...,  0.6245, -0.2639,  0.2722],
        [ 0.2932,  0.3372, -0.1980,  ...,  0.5459, -0.3379,  0.1130]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[-0.5259, -0.2502,  0.0482,  ...,  0.8340, -0.0171,  0.2095],
        [ 0.0257, -0.2014, -0.0242,  ...,  0.3796,  0.0238,  0.3911],
        [ 0.2166, -0.0300, -0.3989,  ...,  0.3416, -0.1102,  0.1545],
        [ 0.0725, -0.2104, -0.5171,  ...,  0.8750,  0.0759, -0.0309]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.1514,  0.0459, -0.5244,  ...,  0.3171, -0.0070,  0.1871],
        [ 0.4084, -0.3818, -0.2808,  ...,  0.8330, -0.2142,  0.2429],
        [-0.1771,  0.1664, -0.4492,  ...,  1.0010,  0.0113,  0.2751],
        [ 0.3699, -0.0663, -0.3367,  ...,  0.5293,  0.0310,  0.1613]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.2532, -0.2048, -0.2097,  ...,  0.6230,  0.0109,  0.1439],
        [ 0.3767,  0.2546, -0.3906,  ...,  0.7622, -0.1189, -0.1665],
        [ 0.3469,  0.0820, -0.0227,  ...,  0.6445, -0.1070, -0.0244],
        [ 0.3906,  0.1221, -0.2771,  ...,  0.6777,  0.0652, -0.1375]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[-0.2362,  0.5576, -0.5464,  ...,  0.5913,  0.1259, -0.3994],
        [-0.0614, -0.1237, -0.3376,  ...,  0.7295, -0.1683,  0.3420],
        [ 0.2233, -0.0130, -0.3572,  ...,  0.6934, -0.1266,  0.3521],
        [ 0.3477, -0.1550, -0.3318,  ...,  0.6499, -0.0132, -0.1364]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

outputs_teacher tensor([[ 0.2881, -0.0511, -0.1379,  ...,  0.3000, -0.0218,  0.4954],
        [ 0.3689,  0.0612, -0.3093,  ...,  0.8130, -0.0930,  0.0816],
        [ 0.2988, -0.4927, -0.4114,  ...,  0.6538,  0.0022, -0.1063],
        [-0.0255,  0.0596, -0.2346,  ...,  0.6528, -0.0359, -0.0077]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.3462,  0.0217, -0.6968,  ...,  0.2649, -0.0319,  0.3379],
        [ 0.0986,  0.1879, -0.3506,  ...,  0.7334,  0.1807,  0.2191],
        [ 0.3142, -0.1401, -0.3792,  ...,  0.3479,  0.1425, -0.2242],
        [ 0.3530,  0.2605, -0.2700,  ...,  0.5024,  0.0731,  0.0203]],
     

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 3.1006e-02,  3.6328e-01, -2.8052e-01,  ...,  1.2002e+00,
         -9.3937e-04, -6.2683e-02],
        [ 1.0437e-01,  1.3928e-01, -4.8145e-01,  ...,  9.2480e-01,
         -1.1426e-01,  1.2842e-01],
        [ 4.8901e-01, -1.2341e-01, -4.6509e-01,  ...,  4.9243e-01,
          3.7012e-01,  4.4141e-01],
        [ 1.1469e-01,  1.0272e-01, -5.3613e-01,  ...,  1.0684e+00,
          1.5173e-01,  4.2610e-03]], device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [n

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.4993, -0.2537, -0.4104,  ...,  0.5884,  0.0975,  0.1915],
        [-0.0342,  0.0884, -0.3374,  ...,  0.5933, -0.0878,  0.4041],
        [-0.1227,  0.0654,  0.0189,  ...,  0.9404, -0.0526, -0.0732],
        [ 0.3074,  0.2439, -0.1794,  ...,  0.9888, -0.0156,  0.0352]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[-1.4913e-04, -1.6479e-02, -1.8213e-01,  ...,  5.3369e-01,
         -1.9226e-01,  1.1359e-01],
        [ 2.2168e-01, -7.1594e-02, -3.0518e-01,  ...,  5.2686e-01,
         -3.4058e-01, -8.1543e-02],
        [ 3.3643e-01,  7.0703e-01, -4.2627e-01,  ...,  6.0010e-01,
         -3.4637e-02,  2.9419e-01],
        [ 4.0381e-01,  1.5076e-01, -4.7998e-01,  ...,  1.0752e+00,
          1.5808e-01,  2.8223e-01]], device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Loss after 3400 Batch is nan 
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  .

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[-0.0430,  0.1117, -0.2539,  ...,  0.7461, -0.2357,  0.2051],
        [-0.0498,  0.0992, -0.3557,  ...,  0.3752, -0.0775,  0.2593],
        [ 0.4319,  0.3352, -0.0587,  ...,  0.6475,  0.0828,  0.2335],
        [-0.0300,  0.0517, -0.0939,  ...,  0.4839, -0.1628,  0.0106]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.4832,  0.0402, -0.1936,  ...,  0.4556, -0.1969, -0.1400],
        [ 0.2695,  0.0133, -0.4622,  ...,  0.6758, -0.2683,  0.0731],
        [ 0.2676,  0.3381, -0.3213,  ...,  0.7056, -0.1454,  0.2766],
        [ 0.0801, -0.1406, -0.6938,  ...,  0.3772, -0.1937,  0.5254]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.1250,  0.2502, -0.2067,  ...,  0.6177, -0.3755, -0.1117],
        [ 0.0640, -0.0043, -0.3127,  ...,  1.0127, -0.0734, -0.2404],
        [-0.0895,  0.0609, -0.4292,  ...,  0.5752,  0.0999, -0.0195],
        [ 0.5234, -0.0144, -0.4436,  ...,  0.5117,  0.0162,  0.0093]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.0636, -0.0687,  0.0150,  ...,  0.6411, -0.0040, -0.0473],
        [ 0.1910,  0.3188, -0.2236,  ...,  0.5859, -0.0379,  0.1606],
        [ 0.3127,  0.1409,  0.0178,  ...,  0.4343, -0.2317,  0.1063],
        [ 0.3025, -0.0391, -0.3340,  ...,  0.6396,  0.0861,  0.0850]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[-0.1223,  0.5098, -0.4624,  ...,  0.8184, -0.1722,  0.0654],
        [ 0.1215,  0.0691, -0.5918,  ...,  0.7056, -0.0113,  0.1096],
        [ 0.0319, -0.2263, -0.4087,  ...,  0.5542, -0.2485,  0.2588],
        [ 0.1770, -0.2135, -0.4114,  ...,  1.0498, -0.1898,  0.3538]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Loss after 3460 Batch is nan 
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBa

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.4504,  0.6074,  0.1061,  ...,  0.5117,  0.1770,  0.5703],
        [ 0.0306, -0.1035, -0.1740,  ...,  0.3350, -0.4001, -0.2917],
        [-0.0028,  0.0189, -0.1876,  ...,  0.5498, -0.0294,  0.2842],
        [ 0.0305, -0.1613, -0.2573,  ...,  0.4487, -0.4304, -0.0235]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.2001, -0.2246, -0.0567,  ...,  0.6899,  0.2136,  0.3430],
        [-0.1606,  0.3298, -0.3516,  ...,  0.5718,  0.1447, -0.0142],
        [-0.2859, -0.0818, -0.1970,  ...,  0.9927, -0.3979, -0.2168],
        [ 0.3838,  0.1659, -0.3381,  ...,  0.6812,  0.3557,  0.0241]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

outputs_teacher tensor([[ 0.3716, -0.2686, -0.2664,  ...,  0.9897,  0.0768,  0.2808],
        [ 0.0072, -0.2820, -0.4165,  ...,  0.6235, -0.4155,  0.2421],
        [ 0.0191, -0.0781, -0.3440,  ...,  0.8501,  0.1562,  0.2832],
        [ 0.0884, -0.1876, -0.2107,  ...,  0.6660, -0.1604,  0.1538]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.4060, -0.2676, -0.3364,  ...,  0.8589,  0.4016,  0.1499],
        [ 0.3113,  0.1429, -0.4148,  ...,  0.3630, -0.2416, -0.2043],
        [-0.0905,  0.0716, -0.2600,  ...,  0.5518,  0.0811, -0.1740],
        [ 0.0253,  0.0858, -0.3257,  ...,  0.5190, -0.0580,  0.0806]],
     

outputs_teacher tensor([[ 0.0872,  0.2419, -0.4519,  ...,  0.9199, -0.1954, -0.1099],
        [ 0.0312,  0.2966, -0.0418,  ...,  0.6646,  0.0802,  0.0671],
        [ 0.0588, -0.0771, -0.1904,  ...,  0.3708, -0.2625,  0.2328],
        [ 0.2152, -0.3059, -0.2090,  ...,  0.5171, -0.1519,  0.3052]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.5024,  0.2496, -0.1436,  ...,  0.4172,  0.0155, -0.0685],
        [ 0.2125,  0.0102, -0.4749,  ...,  0.4473, -0.0944,  0.1799],
        [ 0.3604,  0.0255, -0.2155,  ...,  0.7144,  0.0294,  0.2294],
        [ 0.1810, -0.1777, -0.5122,  ...,  1.0371, -0.1416,  0.4385]],
     

outputs_teacher tensor([[ 0.2424,  0.2350, -0.2839,  ...,  0.6030,  0.3877,  0.2186],
        [ 0.1301,  0.0169, -0.3950,  ...,  0.9136,  0.0599,  0.3809],
        [ 0.3083,  0.3748, -0.5083,  ...,  0.4624, -0.0027, -0.1779],
        [ 0.1688,  0.1626, -0.2908,  ...,  0.7881, -0.1372,  0.0726]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Loss after 3520 Batch is nan 
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.1103,  0.3870, -0.6558,  ...,  0.5186,  0.1527,  0.0693],
        [ 0.3191,  0.3743, -0.4265,  ...,  0.9033,  0.0428,  0.1059],
        [ 0.0216, -0.1119, -0.3550,  ...,  0.7568, -0.2324,  0.1790],
        [ 0.1759,  0.1410, -0.2571,  ...,  0.7

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 2.9639e-01,  3.4961e-01, -3.2568e-01,  ...,  6.5137e-01,
          1.1053e-01, -2.4277e-02],
        [ 1.5857e-01,  4.0747e-01,  6.4026e-02,  ...,  8.2031e-01,
          1.6980e-01,  5.7666e-01],
        [ 2.1167e-01, -1.6739e-02, -1.3074e-01,  ...,  7.8955e-01,
          4.8208e-04,  1.1407e-01],
        [ 1.8921e-01, -1.8726e-01, -2.6196e-01,  ...,  7.0801e-01,
         -2.8394e-01,  4.4312e-01]], device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [n

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.0212,  0.1746, -0.2878,  ...,  1.0869, -0.0441, -0.0769],
        [ 0.0880, -0.0688, -0.1743,  ...,  0.7017, -0.1085,  0.0817],
        [ 0.0905, -0.0341, -0.5386,  ...,  0.9282,  0.2109,  0.1376],
        [ 0.2174, -0.1194,  0.1616,  ...,  0.2561, -0.2881, -0.0545]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Loss after 3550 Batch is nan 
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBa

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.0597,  0.1580, -0.3560,  ...,  0.4570, -0.1093,  0.3027],
        [ 0.0230,  0.2510, -0.2856,  ...,  0.5688, -0.1197,  0.1222],
        [-0.0505,  0.0368, -0.5039,  ...,  0.5117,  0.1925, -0.2532],
        [ 0.2257,  0.3816, -0.2629,  ...,  0.8481,  0.3174, -0.1119]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.0744, -0.2729, -0.2065,  ...,  0.6396, -0.1265,  0.4622],
        [ 0.0837,  0.0649, -0.3020,  ...,  0.7754,  0.0349, -0.0930],
        [-0.0422,  0.0495, -0.2222,  ...,  0.6226,  0.0425, -0.1047],
        [ 0.1851,  0.2336, -0.1803,  ...,  0.6309, -0.0670,  0.2644]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.4636,  0.1193, -0.1157,  ...,  0.6162, -0.1544,  0.3660],
        [ 0.3796,  0.2496, -0.3394,  ...,  0.7900, -0.1711, -0.0109],
        [ 0.1805,  0.2200, -0.2373,  ...,  0.8589,  0.0687,  0.1725],
        [ 0.2612,  0.3059, -0.2620,  ...,  0.5234,  0.0037, -0.0346]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[-0.0046, -0.1909, -0.0959,  ...,  0.4346, -0.0464, -0.0420],
        [-0.0064,  0.0660, -0.4016,  ...,  0.8438, -0.2271, -0.0148],
        [ 0.2668,  0.1509, -0.3130,  ...,  0.4717, -0.0903,  0.1956],
        [ 0.2145,  0.3071,  0.0345,  ...,  0.3784,  0.0616, -0.2069]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[-0.1193,  0.3979,  0.0377,  ...,  0.9238, -0.2715,  0.3630],
        [ 0.6084,  0.2441, -0.2163,  ...,  0.8584,  0.0759, -0.5093],
        [-0.1804, -0.1643, -0.3037,  ...,  0.7988,  0.0490, -0.2551],
        [ 0.3264,  0.0260, -0.3569,  ...,  0.4670,  0.0574,  0.2382]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Loss after 3610 Batch is nan 
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBa

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.2397,  0.2271, -0.0834,  ...,  0.4490, -0.1333,  0.1610],
        [ 0.1299,  0.2257, -0.0575,  ...,  0.9355, -0.1133, -0.0759],
        [ 0.1683,  0.4287, -0.2426,  ...,  0.5591,  0.0659,  0.5361],
        [ 0.1512, -0.0097, -0.3508,  ...,  0.9346, -0.1060,  0.0720]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.0202,  0.3552, -0.3320,  ...,  0.7891, -0.1241, -0.1660],
        [ 0.1201, -0.3120, -0.1328,  ...,  0.7144,  0.2554,  0.1077],
        [ 0.0047,  0.2356, -0.1044,  ...,  0.6304, -0.1661, -0.0841],
        [ 0.3391,  0.3408, -0.6831,  ...,  0.5532,  0.1367,  0.2556]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.3992,  0.1560, -0.2998,  ...,  0.6509, -0.2715,  0.2917],
        [ 0.3218,  0.2607, -0.3228,  ...,  0.4473,  0.2104,  0.1643],
        [ 0.0777,  0.1758, -0.2988,  ...,  0.6299, -0.5024,  0.3162],
        [ 0.3855,  0.2622, -0.2573,  ...,  0.9282,  0.3013,  0.0318]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.3179,  0.1862, -0.1203,  ...,  0.7007, -0.0851, -0.0831],
        [ 0.6318,  0.1847, -0.3230,  ...,  0.6958, -0.2317,  0.1520],
        [ 0.3992,  0.0944, -0.4165,  ...,  0.7476, -0.0636,  0.1970],
        [ 0.0410,  0.0070, -0.0273,  ...,  0.9072, -0.3110,  0.0284]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[-0.2124,  0.0833, -0.0600,  ...,  0.8340, -0.1666, -0.0695],
        [ 0.1874,  0.0674, -0.1801,  ...,  0.7007,  0.2322, -0.0971],
        [ 0.4055,  0.2123, -0.2505,  ...,  0.6948,  0.3035, -0.0677],
        [ 0.0844,  0.6304, -0.2413,  ...,  1.0947, -0.0825, -0.1379]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[-0.0357,

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.1229,  0.3948, -0.3328,  ...,  0.8306,  0.0575, -0.0034],
        [ 0.1014,  0.2612, -0.2283,  ...,  0.8599,  0.0346, -0.0850],
        [ 0.2019,  0.0488, -0.4011,  ...,  0.6226, -0.0637,  0.1852],
        [ 0.1345,  0.0339, -0.4258,  ...,  0.9326,  0.1714, -0.1547]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.2286, -0.0425, -0.1572,  ...,  0.3777, -0.1309,  0.2888],
        [ 0.1114,  0.3369, -0.4404,  ...,  0.3330, -0.2474, -0.3010],
        [ 0.2795, -0.2074, -0.1614,  ...,  0.6689, -0.0688, -0.1085],
        [ 0.3108, -0.0916, -0.3816,  ...,  0.4131,  0.0015,  0.0600]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[-0.0999,  0.4299,  0.0936,  ...,  0.5171, -0.2496,  0.0096],
        [ 0.1842,  0.0966, -0.3044,  ...,  0.5854,  0.1226,  0.1987],
        [ 0.2598,  0.0010,  0.0392,  ...,  0.5811, -0.3044, -0.0973],
        [ 0.0158,  0.0183, -0.2086,  ...,  0.6631, -0.2883,  0.2910]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Loss after 3720 Batch is nan 
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.4150,  0.2291, -0.5298,  ...,  0.9346, -0.0630,  0.0667],
        [ 0.2859, -0.0060, -0.0096,  ...,  0.6152,  0.3364, -0.1512],
        [ 0.2764, -0.0205, -0.1216,  ...,  0.6235, -0.3269,  0.0920],
        [ 0.6260,  0.1188, -0.2211,  ...,  0.5234,  0.1428,  0.1200]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBa

outputs_teacher tensor([[ 0.0115, -0.1158, -0.2256,  ...,  0.7593, -0.0830,  0.1011],
        [ 0.3787,  0.2539, -0.5288,  ...,  0.4062,  0.7236, -0.0729],
        [-0.0717, -0.1042, -0.3347,  ...,  1.1094, -0.0029, -0.0662],
        [ 0.4094,  0.4856, -0.3840,  ...,  0.9272,  0.1362,  0.2698]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.1791, -0.1429, -0.1787,  ...,  0.4529, -0.1152, -0.1824],
        [ 0.0215,  0.3149, -0.4048,  ...,  0.8950, -0.0916,  0.0137],
        [ 0.0429,  0.2395, -0.2279,  ...,  0.8164,  0.2849, -0.1758],
        [-0.0305,  0.1447, -0.2549,  ...,  0.7480, -0.2305,  0.5151]],
     

tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[-0.1492,  0.2622, -0.2815,  ...,  0.6416,  0.1893,  0.0461],
        [ 0.2898,  0.3989, -0.0854,  ...,  0.7085, -0.2285,  0.1138],
        [ 0.2250,  0.1542, -0.4426,  ...,  0.6885,  0.4409,  0.1660],
        [ 0.5781,  0.0153, -0.2913,  ...,  0.8496, -0.0236,  0.0767]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[-0.0394,  0.5029, -0.5024,  ...,  0.6978, -0.1109,  0.4294],
        [ 0.0985,  0.2006, -0.4634,  ...,  0.4373, -0.0981,  0.2539],
        [ 0.2079, -0.2605, -0.2135,  ...,  0.7852,  0.1332, -0.0919],
        [ 0.2671,  0.1838, -0.2256,  ...,  0.8110, -0.1476,  0.0209]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Loss after 3760 Batch is nan 
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBa

outputs_teacher tensor([[ 0.1659,  0.1655, -0.4011,  ...,  0.8188, -0.0612,  0.4746],
        [ 0.1843, -0.3037, -0.7803,  ...,  0.7202,  0.2498,  0.5317],
        [ 0.3293,  0.0926, -0.4341,  ...,  0.6572,  0.0277,  0.2024],
        [ 0.1006,  0.1639, -0.3042,  ...,  0.5728, -0.0499,  0.0587]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[-0.1252,  0.0893, -0.1089,  ...,  0.5576, -0.1222, -0.2612],
        [ 0.1635,  0.2340, -0.4414,  ...,  0.7046,  0.0941,  0.0601],
        [ 0.2407, -0.1110, -0.1991,  ...,  0.8892, -0.2142, -0.0059],
        [ 0.1076,  0.3420, -0.3772,  ...,  0.8501,  0.0352,  0.2754]],
     

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.1631,  0.0839, -0.2302,  ...,  1.0273, -0.1708,  0.2194],
        [ 0.1940, -0.3972, -0.4790,  ...,  0.5542, -0.5020,  0.1050],
        [-0.0622,  0.2296, -0.1962,  ...,  1.0010,  0.0363,  0.2335],
        [ 0.1946, -0.2365, -0.0554,  ...,  0.8984,  0.2086,  0.2732]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.1431,  0.0701, -0.2639,  ...,  0.9019, -0.2175, -0.0658],
        [-0.0829, -0.2209, -0.2126,  ...,  0.5059, -0.1196,  0.0839],
        [ 0.3315, -0.0224, -0.3411,  ...,  0.6797, -0.0580,  0.5020],
        [ 0.3330,  0.2805, -0.2617,  ...,  0.6934, -0.0955,  0.2969]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.4319,

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.1583,  0.5044, -0.1682,  ...,  0.7212,  0.2927, -0.2090],
        [-0.1561,  0.0378, -0.1748,  ...,  1.1514, -0.1542, -0.1283],
        [ 0.3169, -0.1163, -0.4368,  ...,  0.6328,  0.1455,  0.1973],
        [-0.0479,  0.1267, -0.2595,  ...,  0.7866, -0.1490, -0.0994]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[-0.1157,  0.2322, -0.1597,  ...,  0.9141, -0.0759, -0.1831],
        [ 0.4944,  0.2432, -0.4717,  ...,  1.0791,  0.2854,  0.1189],
        [ 0.3926,  0.1713, -0.4343,  ...,  0.6528,  0.1855, -0.0057],
        [ 0.2844, -0.0101, -0.2137,  ...,  0.8735,  0.0360,  0.2703]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.0910,  0.0146, -0.4358,  ...,  0.7617, -0.1198,  0.3291],
        [ 0.1907,  0.3108, -0.6279,  ...,  0.7441, -0.0076, -0.0436],
        [ 0.1472,  0.1866, -0.4111,  ...,  0.7764, -0.1080,  0.2476],
        [ 0.4700,  0.3442, -0.3574,  ...,  0.4126,  0.0519,  0.1492]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.2296,  0.2024, -0.3914,  ...,  0.9497,  0.1686, -0.2451],
        [ 0.3289,  0.1252, -0.3901,  ...,  0.7690, -0.0129,  0.0118],
        [-0.0331, -0.0024, -0.1599,  ...,  0.9385, -0.1322,  0.1104],
        [ 0.3979, -0.3003, -0.2395,  ...,  0.6753,  0.3030,  0.0989]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.3164, -0.0826, -0.1027,  ...,  0.4983, -0.3201,  0.0325],
        [ 0.1854,  0.4353, -0.3848,  ...,  0.2869, -0.0168, -0.1681],
        [ 0.1332,  0.0713, -0.5933,  ...,  0.8970, -0.1523,  0.0646],
        [ 0.0608,  0.1558, -0.1212,  ...,  0.4861, -0.3889,  0.1134]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.1920,  0.2764, -0.5522,  ...,  0.6191,  0.0438,  0.2830],
        [ 0.2076, -0.1346, -0.3245,  ...,  0.5522,  0.3430,  0.3093],
        [-0.0767,  0.1074, -0.1212,  ...,  0.7676, -0.0892, -0.1128],
        [ 0.3772, -0.1754, -0.2976,  ...,  0.5635, -0.1663,  0.1737]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.1973, -0.1191, -0.2380,  ...,  0.6768, -0.2355,  0.2361],
        [-0.0649,  0.1224, -0.4753,  ...,  0.9604, -0.0616,  0.1754],
        [ 0.5698, -0.1300,  0.2352,  ...,  0.4043,  0.0891,  0.0529],
        [ 0.3318, -0.0066, -0.3315,  ...,  0.6118, -0.0840,  0.0677]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.3035,  0.3621,  0.0191,  ...,  0.8184,  0.2214,  0.0317],
        [-0.0953,  0.1173, -0.2244,  ...,  0.8208,  0.0797, -0.0100],
        [ 0.1935, -0.0374, -0.4785,  ...,  0.6260, -0.0649,  0.2247],
        [ 0.4460, -0.0443, -0.3855,  ...,  0.2500, -0.3230,  0.2050]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[-0.0459,  0.1191, -0.2039,  ...,  0.9424,  0.0231, -0.2029],
        [ 0.3120,  0.3362, -0.2698,  ...,  0.5698,  0.0807, -0.1595],
        [-0.0137,  0.1757, -0.1936,  ...,  0.6953, -0.0651,  0.5376],
        [ 0.1843,  0.4333, -0.2754,  ...,  1.0674, -0.3647,  0.0714]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.0446,  0.1036, -0.1848,  ...,  0.4102,  0.1008,  0.2115],
        [ 0.1138, -0.1825, -0.2378,  ...,  0.6274, -0.0227, -0.1039],
        [ 0.1039,  0.1921, -0.1926,  ...,  0.8423,  0.3472, -0.2408],
        [-0.1998,  0.1278, -0.0330,  ...,  0.8242, -0.0109,  0.1632]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.2354,  0.2054, -0.1168,  ...,  0.8970, -0.0996,  0.4219],
        [ 0.3665,  0.0827, -0.2262,  ...,  0.6919,  0.0956,  0.0060],
        [ 0.3120,  0.1342, -0.4209,  ...,  0.5405, -0.0611,  0.4121],
        [-0.0557, -0.1415, -0.1764,  ...,  0.8008, -0.1722,  0.1390]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.0875, -0.2905, -0.3833,  ...,  1.0273, -0.1718,  0.3992],
        [ 0.0991, -0.1749, -0.1680,  ...,  0.7563, -0.1816,  0.2150],
        [ 0.4924,  0.1058, -0.2367,  ...,  0.8823, -0.0403, -0.0301],
        [ 0.3022,  0.0208, -0.0293,  ...,  0.5488,  0.2507,  0.0574]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.2949,  0.1238, -0.2769,  ...,  1.0430,  0.3103,  0.1349],
        [-0.1063,  0.2725, -0.1769,  ...,  0.7031, -0.0312, -0.0355],
        [ 0.1976, -0.0518, -0.3235,  ...,  0.5239, -0.3362, -0.2095],
        [ 0.1897,  0.0398, -0.2644,  ...,  0.8589, -0.1215,  0.1338]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.1281,  0.0905, -0.3628,  ...,  0.5986, -0.0851,  0.0375],
        [ 0.6099, -0.2267, -0.1604,  ...,  0.5898, -0.5190,  0.2198],
        [ 0.2456,  0.1460, -0.6807,  ...,  0.9316,  0.1605,  0.0670],
        [ 0.3115,  0.2812, -0.1453,  ...,  0.6387, -0.2083,  0.1443]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.0452,  0.0778, -0.3381,  ...,  0.4412,  0.2505, -0.0477],
        [ 0.0020, -0.0977, -0.3386,  ...,  0.4766, -0.0608,  0.2925],
        [ 0.3850, -0.2235, -0.4761,  ...,  0.9033,  0.0526, -0.3594],
        [ 0.2499,  0.3218, -0.3950,  ...,  0.2686, -0.0324,  0.1710]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[-3.5889e-02, -1.8225e-01, -2.7002e-01,  ...,  3.7646e-01,
         -8.5449e-02,  4.2041e-01],
        [-2.5317e-01,  1.4397e-02, -4.6899e-01,  ...,  9.0625e-01,
         -3.5449e-01, -2.7359e-02],
        [ 5.2295e-01,  3.3545e-01, -3.5132e-01,  ...,  8.5889e-01,
         -1.6638e-01,  1.9666e-01],
        [-2.1684e-04,  3.5571e-01, -4.2993e-01,  ...,  9.8730e-01,
         -2.6294e-01,  6.2286e-02]], device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Loss after 3990 Batch is nan 
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  .

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[-0.0755,  0.2507, -0.1373,  ...,  0.7139, -0.1367, -0.4490],
        [ 0.4897, -0.0845, -0.3086,  ...,  0.5371, -0.0202,  0.0185],
        [-0.1105, -0.0984, -0.2386,  ...,  0.5762, -0.1498,  0.2920],
        [-0.0212, -0.0717, -0.0882,  ...,  0.6230,  0.0402,  0.2383]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[-0.2795,  0.1243, -0.3069,  ...,  0.5278, -0.2764,  0.0034],
        [ 0.3328,  0.0153, -0.2094,  ...,  0.3342, -0.3035,  0.1841],
        [ 0.1494,  0.2369, -0.3706,  ...,  0.4854, -0.0518, -0.1186],
        [ 0.1649, -0.0637, -0.3806,  ...,  0.5200,  0.1993,  0.2722]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[-0.1301,  0.2104, -0.0374,  ...,  0.9805, -0.2771, -0.0777],
        [ 0.1859,  0.3242, -0.3333,  ...,  1.1904, -0.0401,  0.3762],
        [ 0.1107, -0.2059,  0.3796,  ...,  0.6416,  0.0579, -0.2354],
        [ 0.2703,  0.3054, -0.3433,  ...,  0.8745,  0.2390,  0.1370]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 2.9541e

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.2798,  0.1172, -0.5200,  ...,  0.5269,  0.2969, -0.0825],
        [-0.0014,  0.3398, -0.2805,  ...,  1.1865,  0.1169,  0.1012],
        [ 0.3979,  0.4380, -0.2466,  ...,  0.7280, -0.0504,  0.1753],
        [-0.1052, -0.1743, -0.1014,  ...,  0.4836,  0.0956,  0.1120]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Loss after 4040 Batch is nan 
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBa

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.4043,  0.1536, -0.3083,  ...,  0.6997, -0.4243,  0.2620],
        [ 0.1065,  0.0235, -0.4663,  ...,  0.8789,  0.0937, -0.1100],
        [ 0.0142, -0.1458, -0.3066,  ...,  1.1436,  0.0056,  0.3933],
        [ 0.3081,  0.5654, -0.2661,  ...,  0.4944,  0.0312, -0.4001]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.4990,  0.0348, -0.3086,  ...,  0.5947,  0.1710,  0.0958],
        [ 0.3271,  0.0850, -0.1092,  ...,  0.6191,  0.0514, -0.2050],
        [ 0.2871, -0.0322,  0.0227,  ...,  0.5894,  0.3103, -0.0202],
        [ 0.4255,  0.5029, -0.2056,  ...,  0.4487, -0.0334, -0.1876]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.3735, -0.0142, -0.3394,  ...,  0.8354, -0.2727,  0.2922],
        [ 0.1575,  0.4763, -0.2874,  ...,  0.7388, -0.0830,  0.0131],
        [ 0.1450, -0.0255, -0.2651,  ...,  0.5474,  0.0665,  0.4851],
        [-0.0754,  0.2705, -0.2153,  ...,  1.0596, -0.3054,  0.3696]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.4170,  0.1864, -0.4438,  ...,  0.8057,  0.2471, -0.1544],
        [ 0.4978,  0.2085, -0.1116,  ...,  0.9985, -0.1342,  0.0917],
        [ 0.4507,  0.0861, -0.4504,  ...,  0.6270,  0.0724,  0.2170],
        [ 0.3855,  0.1328, -0.2927,  ...,  0.4524, -0.3201,  0.5083]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[-0.0175, -0.0920, -0.5034,  ...,  0.6265, -0.0644, -0.0539],
        [-0.0588,  0.1737, -0.1936,  ...,  0.9409, -0.1805, -0.0758],
        [ 0.3076,  0.0100, -0.4351,  ...,  0.6064, -0.0732,  0.2085],
        [ 0.3623,  0.2369, -0.4775,  ...,  0.4287, -0.1760,  0.2803]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Loss after 4100 Batch is nan 
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBa

outputs_teacher tensor([[ 0.6113,  0.0034, -0.4329,  ...,  0.2103,  0.0524,  0.0448],
        [-0.0073,  0.2411, -0.6377,  ...,  0.5615, -0.1296,  0.3364],
        [ 0.1043,  0.0499, -0.2234,  ...,  0.7476, -0.0715,  0.0920],
        [ 0.2455,  0.0592, -0.2861,  ...,  0.8027, -0.0044,  0.5264]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.2883,  0.1138, -0.2316,  ...,  0.7969,  0.0666,  0.0476],
        [ 0.1208, -0.3806, -0.5610,  ...,  0.9414,  0.1803, -0.1738],
        [ 0.2896,  0.0517, -0.1737,  ...,  0.6807, -0.1943,  0.1246],
        [-0.0598,  0.2100, -0.2554,  ...,  0.5923, -0.3813,  0.3069]],
     

outputs_teacher tensor([[-0.1158,  0.1372,  0.0654,  ...,  0.8589,  0.0225, -0.0210],
        [ 0.3672,  0.0977, -0.3376,  ...,  0.7305, -0.2467, -0.1768],
        [ 0.0445,  0.0500, -0.4221,  ...,  0.7676,  0.3770,  0.0309],
        [-0.0300,  0.2578, -0.4058,  ...,  0.5312, -0.1621, -0.2281]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.2957,  0.0113, -0.4753,  ...,  0.7373, -0.1611,  0.1125],
        [ 0.1854, -0.0737, -0.1481,  ...,  0.3391,  0.1006,  0.3789],
        [ 0.0397,  0.1584,  0.0618,  ...,  0.5688, -0.2783, -0.1774],
        [ 0.3025,  0.2009, -0.2354,  ...,  0.8696, -0.1076,  0.0359]],
     

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[-2.1072e-02, -6.7810e-02, -7.8613e-02,  ...,  7.0703e-01,
         -3.0835e-01,  1.6342e-02],
        [-1.6663e-01, -3.7140e-02, -4.5532e-01,  ...,  4.1479e-01,
         -5.6610e-02, -6.4880e-02],
        [ 5.8031e-04, -2.1558e-01, -3.4937e-01,  ...,  8.7744e-01,
          8.1177e-02, -8.0933e-02],
        [ 2.4765e-02,  2.5024e-01, -3.5156e-01,  ...,  5.0586e-01,
          9.4528e-03,  1.4868e-01]], device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [n

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[-0.2169,  0.1222, -0.1478,  ...,  0.6797, -0.2627, -0.0885],
        [ 0.5757,  0.1039, -0.3923,  ...,  0.7095, -0.1504,  0.5391],
        [ 0.2522, -0.0164, -0.2820,  ...,  0.7681,  0.1575, -0.0272],
        [-0.2612,  0.0925,  0.0256,  ...,  1.3096,  0.0022, -0.0142]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Loss after 4150 Batch is nan 
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBa

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[-0.0666,  0.1818, -0.2969,  ...,  0.5713, -0.0989, -0.3479],
        [ 0.4390,  0.3342, -0.3499,  ...,  0.8789, -0.1146, -0.0103],
        [ 0.2288,  0.0892, -0.4182,  ...,  0.6362, -0.2222,  0.0431],
        [-0.0349,  0.0666, -0.3433,  ...,  0.7515, -0.1888, -0.1111]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.5332,  0.4636, -0.3191,  ...,  0.6519,  0.1899, -0.1028],
        [ 0.3647, -0.0686,  0.0321,  ...,  0.5786,  0.3911,  0.0453],
        [ 0.4712,  0.0511, -0.3066,  ...,  0.3672, -0.0660,  0.0314],
        [ 0.0240, -0.0261, -0.0864,  ...,  0.6626, -0.0271,  0.0546]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.1169, -0.0137, -0.3997,  ...,  0.7773, -0.1122,  0.0857],
        [ 0.4602,  0.4392, -0.5103,  ...,  0.9321, -0.0125,  0.0931],
        [ 0.0590, -0.2148, -0.3838,  ...,  0.5063,  0.2396,  0.3130],
        [ 0.3877,  0.2771, -0.3079,  ...,  0.5820,  0.0548,  0.2325]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.1368,  0.0988, -0.0459,  ...,  0.7231, -0.1118, -0.1478],
        [ 0.3159,  0.0673, -0.4763,  ...,  0.8003,  0.0092,  0.3091],
        [ 0.2429,  0.2368, -0.3779,  ...,  0.8125, -0.2290,  0.0660],
        [ 0.1628,  0.2054, -0.2659,  ...,  0.5059, -0.1572,  0.5205]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

outputs_teacher tensor([[-0.1653,  0.0229, -0.3955,  ...,  0.6143,  0.2603,  0.0461],
        [ 0.3992,  0.1899, -0.4949,  ...,  0.9634,  0.2205,  0.2152],
        [ 0.1630, -0.0876, -0.3828,  ...,  0.9150,  0.0114,  0.3074],
        [ 0.1379, -0.1815, -0.2018,  ...,  0.9414, -0.0425,  0.0656]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Loss after 4210 Batch is nan 
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[-0.3611,  0.0776, -0.1700,  ...,  0.5347,  0.1335, -0.1991],
        [ 0.2634,  0.1346, -0.1755,  ...,  0.1989, -0.2673,  0.2690],
        [ 0.2244,  0.3887, -0.0061,  ...,  0.4238, -0.2954,  0.1550],
        [ 0.1384,  0.3306, -0.5732,  ...,  0.7

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[-0.2034,  0.0757, -0.2146,  ...,  0.8291,  0.0765,  0.3159],
        [ 0.4910, -0.0970, -0.3306,  ...,  0.4895,  0.2467, -0.0236],
        [ 0.3569, -0.0373, -0.4539,  ...,  0.9141,  0.1037, -0.0518],
        [ 0.2357,  0.1588, -0.4495,  ...,  0.8599, -0.1373,  0.1426]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.1111,  0.2969, -0.4026,  ...,  0.7334,  0.1852, -0.0478],
        [-0.1835, -0.0699, -0.2358,  ...,  1.0469, -0.1693, -0.0955],
        [-0.2793,  0.1339, -0.0649,  ...,  0.5742, -0.0433, -0.1302],
        [ 0.3064,  0.2124, -0.4385,  ...,  0.8833, -0.0059, -0.0075]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Loss after 4240 Batch is nan 
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBa

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.1022, -0.1267, -0.5054,  ...,  1.1895,  0.0129, -0.1201],
        [-0.0129,  0.1082, -0.2983,  ...,  1.0488, -0.1050,  0.1908],
        [ 0.3315,  0.0944, -0.2632,  ...,  0.7383,  0.0125, -0.0632],
        [ 0.3015, -0.1295, -0.6206,  ...,  0.8726, -0.0627,  0.2186]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.5317, -0.2290, -0.0608,  ...,  0.6660,  0.4646,  0.0235],
        [ 0.3098,  0.0548, -0.1276,  ...,  0.6812, -0.1417,  0.0755],
        [ 0.2424, -0.0020, -0.3977,  ...,  0.5127, -0.2385, -0.2010],
        [ 0.2532, -0.1565, -0.1260,  ...,  0.5649,  0.0356,  0.2115]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.2527,  0.1180, -0.3718,  ...,  0.7559,  0.0160,  0.1984],
        [ 0.0840,  0.1874, -0.3372,  ...,  0.6167, -0.2659,  0.3123],
        [ 0.2019, -0.2067, -0.5405,  ...,  1.4648, -0.0984, -0.1626],
        [ 0.2455,  0.2051, -0.3884,  ...,  0.6367, -0.2131,  0.2522]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.2666,  0.2886, -0.4084,  ...,  0.9360,  0.1234, -0.0726],
        [ 0.3730,  0.0151, -0.1499,  ...,  0.7041, -0.1331, -0.0417],
        [ 0.0573,  0.1368, -0.3499,  ...,  0.4822, -0.1045,  0.2654],
        [ 0.0906, -0.0522, -0.2500,  ...,  0.8564, -0.3057,  0.0915]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.0305, -0.1000,  0.1042,  ...,  0.5977,  0.2476, -0.0114],
        [ 0.2834,  0.1787, -0.0779,  ...,  0.6343, -0.1842, -0.0768],
        [ 0.3030,  0.1827, -0.1581,  ...,  0.7837, -0.2445, -0.1133],
        [ 0.5005,  0.2308, -0.3374,  ...,  0.7988, -0.1887,  0.2847]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Loss after 4300 Batch is nan 
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBa

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.3179,  0.0014, -0.2842,  ...,  0.5317, -0.2142, -0.0858],
        [ 0.3760,  0.1812, -0.5093,  ...,  0.7388, -0.1103,  0.0286],
        [ 0.3357,  0.3330, -0.2961,  ...,  0.4792,  0.1431,  0.1106],
        [ 0.0776,  0.4258, -0.3523,  ...,  0.5488, -0.0284, -0.2358]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.1952, -0.1934, -0.2164,  ...,  0.6680,  0.0292,  0.0831],
        [ 0.4944,  0.3328, -0.0841,  ...,  1.1201,  0.2595,  0.3142],
        [ 0.3684,  0.0021, -0.6167,  ...,  0.6987,  0.0964,  0.0122],
        [ 0.1237, -0.0751, -0.1445,  ...,  0.2817,  0.2074,  0.0823]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.0413,  0.1439, -0.1578,  ...,  0.7095,  0.0695, -0.1785],
        [ 0.4043,  0.2883, -0.4128,  ...,  0.4956,  0.1133,  0.1254],
        [ 0.0975,  0.1261, -0.3877,  ...,  1.0439,  0.1652, -0.0425],
        [-0.0492,  0.2527, -0.0121,  ...,  0.5415,  0.0867, -0.0281]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.0532,

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.0934, -0.0213, -0.2959,  ...,  0.8389, -0.0336,  0.2228],
        [ 0.2507, -0.1622, -0.0119,  ...,  0.6499, -0.1741,  0.0978],
        [ 0.3579,  0.2297, -0.2448,  ...,  0.8188, -0.1368,  0.1854],
        [ 0.1462,  0.2484, -0.5771,  ...,  0.7500, -0.2651,  0.0385]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[-0.0076, -0.2656, -0.2107,  ...,  0.6367,  0.0059,  0.1647],
        [ 0.4888,  0.1470, -0.2451,  ...,  0.1632,  0.1360, -0.0301],
        [ 0.1890,  0.3445, -0.4602,  ...,  0.7983, -0.0554,  0.0066],
        [ 0.3865, -0.0809, -0.4360,  ...,  0.8423, -0.1078,  0.0519]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Loss after 4360 Batch is nan 
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBa

outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.3352,  0.0529, -0.3506,  ...,  0.8613,  0.0386, -0.0092],
        [ 0.3384,  0.0628, -0.2676,  ...,  0.6074,  0.2440, -0.1029],
        [ 0.0080,  0.1119, -0.3730,  ...,  1.0840,  0.2239,  0.4651],
        [ 0.3069, -0.3777, -0.5024,  ...,  0.7544, -0.0257,  0.0264]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.1155,

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.2661,  0.2856, -0.2917,  ...,  0.6128,  0.0360, -0.1462],
        [ 0.4280, -0.0940, -0.3179,  ...,  0.9287, -0.1760,  0.2231],
        [ 0.2871,  0.2063, -0.3625,  ...,  0.6104, -0.0164,  0.1759],
        [ 0.4316,  0.1748, -0.3401,  ...,  0.5093, -0.0673, -0.2568]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

outputs_teacher tensor([[ 0.1232,  0.1653, -0.2542,  ...,  1.0137, -0.2274, -0.3579],
        [ 0.2847, -0.0628, -0.4258,  ...,  0.8643,  0.0413, -0.0033],
        [ 0.3899,  0.3057, -0.4756,  ...,  0.9443,  0.1125,  0.0784],
        [ 0.2002, -0.0460, -0.1376,  ...,  0.9819, -0.1545,  0.0362]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.1625, -0.2983, -0.4944,  ...,  0.9800, -0.2219,  0.2590],
        [ 0.5464,  0.5264, -0.2080,  ...,  0.5020, -0.2386, -0.3274],
        [-0.0410,  0.0684, -0.2202,  ...,  0.7202, -0.0023,  0.0102],
        [ 0.2788,  0.3367, -0.1936,  ...,  0.9209, -0.0050, -0.1588]],
     

outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.0600, -0.0269, -0.1683,  ...,  0.8081, -0.2517, -0.0163],
        [ 0.4121, -0.0080, -0.7095,  ...,  0.4204,  0.0482, -0.0080],
        [ 0.1192,  0.4790, -0.1084,  ...,  0.5503, -0.0495, -0.3330],
        [ 0.3916,  0.1674, -0.4148,  ...,  0.9487,  0.0101,  0.3525]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.4243,

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.3267, -0.0500, -0.2341,  ...,  0.5161,  0.0823,  0.0791],
        [ 0.0083, -0.2120, -0.2362,  ...,  0.4060, -0.1256,  0.3464],
        [ 0.3989,  0.3181, -0.4597,  ...,  0.7593,  0.0379,  0.2234],
        [-0.0746, -0.1140, -0.3865,  ...,  0.6992,  0.1321,  0.3215]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[-0.0499,  0.6162, -0.4812,  ...,  0.8828, -0.0476,  0.0338],
        [ 0.6392, -0.3870, -0.3887,  ...,  0.7349,  0.2859,  0.3762],
        [ 0.4224,  0.1683, -0.3398,  ...,  0.6904,  0.1506,  0.2217],
        [-0.0226, -0.0367, -0.2107,  ...,  0.8218, -0.4197,  0.0365]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

outputs_teacher tensor([[ 0.5640,  0.1483, -0.0978,  ...,  0.6694, -0.1709, -0.2408],
        [-0.1163, -0.1165, -0.1185,  ...,  0.7759,  0.0925, -0.1009],
        [-0.0664,  0.2812, -0.4766,  ...,  0.6484, -0.1321, -0.3743],
        [ 0.5498,  0.1368, -0.0188,  ...,  0.5498, -0.2659, -0.0460]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.0080,  0.2183, -0.0711,  ...,  0.8682, -0.0035, -0.0820],
        [ 0.3271, -0.0833, -0.1689,  ...,  0.5122,  0.0092, -0.2439],
        [ 0.0747,  0.1176, -0.3164,  ...,  0.5337,  0.1389,  0.1440],
        [ 0.3564,  0.0389, -0.5757,  ...,  0.4031,  0.2279,  0.2329]],
     

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.8384,  0.2698,  0.1439,  ...,  0.6440,  0.2411,  0.0912],
        [ 0.2734,  0.3245, -0.2166,  ...,  0.5820, -0.0296,  0.0808],
        [ 0.2229,  0.1676, -0.4128,  ...,  1.0508, -0.1962,  0.1021],
        [ 0.0634,  0.1831, -0.2013,  ...,  0.5283,  0.0783,  0.2927]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.1533, -0.3235, -0.3596,  ...,  0.4919, -0.3816,  0.3203],
        [-0.1802, -0.0112,  0.0545,  ...,  0.7549, -0.3022,  0.3269],
        [ 0.3164,  0.1951, -0.2944,  ...,  0.6650, -0.1907, -0.1047],
        [ 0.4861,  0.2949, -0.5166,  ...,  0.7939, -0.0321, -0.0555]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.1333,  0.4375, -0.1528,  ...,  0.7090,  0.1043, -0.1770],
        [ 0.1422,  0.0875, -0.5151,  ...,  0.3152,  0.0042,  0.3323],
        [ 0.3381, -0.2156, -0.0167,  ...,  0.5312,  0.3870,  0.4204],
        [ 0.2715,  0.2443, -0.4016,  ...,  0.5229, -0.1390,  0.0929]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.0641,  0.1451, -0.3115,  ...,  0.3076, -0.2952,  0.2844],
        [ 0.2432, -0.3308, -0.0470,  ...,  0.6255,  0.1926, -0.1809],
        [-0.0085, -0.0542, -0.2336,  ...,  0.9365, -0.2345,  0.2678],
        [ 0.1108, -0.1099, -0.4343,  ...,  0.9150, -0.2131, -0.2502]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

outputs_teacher tensor([[ 0.5933,  0.1753, -0.4617,  ...,  0.8931,  0.3088,  0.2144],
        [ 0.4175,  0.0714,  0.0028,  ...,  0.8794,  0.1666, -0.2257],
        [-0.1083, -0.1567, -0.3931,  ...,  0.9893, -0.0113,  0.3362],
        [ 0.4001,  0.1788, -0.3203,  ...,  0.6396,  0.0667, -0.1035]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Loss after 4510 Batch is nan 
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.3491, -0.2314, -0.2578,  ...,  0.6143,  0.0384,  0.1262],
        [ 0.6553,  0.1097, -0.4614,  ...,  0.4797, -0.0671,  0.3833],
        [ 0.2561,  0.1426, -0.3872,  ...,  1.2422, -0.2690,  0.1046],
        [ 0.4133,  0.0178, -0.1279,  ...,  0.5

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.0801,  0.0782, -0.4087,  ...,  0.4834, -0.2410,  0.1387],
        [ 0.0931, -0.0629, -0.3599,  ...,  0.6948,  0.0916, -0.2969],
        [-0.0777,  0.1111, -0.7031,  ...,  0.8081,  0.2340,  0.0784],
        [-0.0852, -0.2910, -0.3801,  ...,  0.9038, -0.0454,  0.0979]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Loss after 4540 Batch is nan 
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.0454, -0.2019, -0.2815,  ...,  0.8984,  0.0507, -0.2637],
        [ 0.0886, -0.0051, -0.2834,  ...,  0.8013,  0.0191,  0.2098],
        [ 0.0654,  0.0965, -0.3411,  ...,  0.7324, -0.0865, -0.1276],
        [ 0.4456,  0.1044, -0.4604,  ...,  0.6558,  0.3984,  0.0629]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBa

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.4124,  0.1515, -0.2864,  ...,  0.8130,  0.1699,  0.0278],
        [ 0.2277,  0.0196, -0.1006,  ...,  0.5298, -0.2566,  0.2566],
        [ 0.0247, -0.2214, -0.5225,  ...,  0.5151, -0.3938,  0.0571],
        [-0.1536, -0.0829, -0.2078,  ...,  0.8130, -0.1656,  0.2993]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.4951,  0.3416, -0.4617,  ...,  0.5767, -0.2800, -0.0307],
        [ 0.3538,  0.1721, -0.0142,  ...,  0.7861,  0.3337,  0.0124],
        [ 0.0967, -0.1166, -0.3398,  ...,  0.8760, -0.3792,  0.1385],
        [-0.4851, -0.0931, -0.2512,  ...,  0.6978, -0.2048,  0.0359]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.0638,  0.3491, -0.1558,  ...,  0.6499,  0.2430, -0.0539],
        [ 0.0340, -0.0053, -0.1664,  ...,  0.5537, -0.0972,  0.2622],
        [ 0.1004,  0.1003, -0.2374,  ...,  0.5415,  0.2487,  0.2146],
        [-0.0965,  0.1729, -0.2137,  ...,  0.5015,  0.0726, -0.3064]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.2223, -0.0363, -0.3572,  ...,  0.6919, -0.0910,  0.5562],
        [ 0.3562,  0.0611, -0.5454,  ...,  0.5508, -0.1011,  0.1549],
        [ 0.2839,  0.1046, -0.1776,  ...,  0.7153,  0.1036, -0.0266],
        [ 0.1399, -0.0324, -0.2708,  ...,  0.6860, -0.0882,  0.1254]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.2949,  0.2534, -0.2214,  ...,  0.7979,  0.2490, -0.0235],
        [ 0.0676, -0.1226, -0.0955,  ...,  0.9019, -0.2100,  0.2866],
        [ 0.0329,  0.5093, -0.0279,  ...,  0.4478, -0.0424, -0.1882],
        [ 0.4797, -0.0316, -0.3240,  ...,  0.8042, -0.1852,  0.1870]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.2350,

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.0902,  0.0359, -0.1631,  ...,  0.6514,  0.1719,  0.0714],
        [ 0.2659,  0.3298, -0.1134,  ...,  0.6045, -0.1740, -0.2247],
        [ 0.2747,  0.0306, -0.3232,  ...,  0.1501,  0.1926,  0.0768],
        [ 0.1711,  0.0425, -0.3630,  ...,  1.0303, -0.0119, -0.1180]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.0927,  0.0223, -0.1583,  ...,  1.1631,  0.1161,  0.1694],
        [ 0.2686,  0.0777, -0.5527,  ...,  0.9126,  0.1890,  0.0564],
        [-0.1092,  0.1658, -0.4624,  ...,  0.8291,  0.2272, -0.4099],
        [ 0.3391,  0.1710, -0.3391,  ...,  0.4932, -0.0364,  0.0728]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0

outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[-0.0452, -0.0186, -0.2325,  ...,  0.9248,  0.0836, -0.2751],
        [ 0.0536,  0.3101, -0.2617,  ...,  0.5161,  0.3340,  0.1025],
        [ 0.1063,  0.1759, -0.3054,  ...,  0.5859, -0.2610,  0.1198],
        [ 0.1578, -0.1098, -0.3594,  ...,  0.5122, -0.0467,  0.6353]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.2742,

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[-0.1484, -0.1019, -0.2734,  ...,  0.5688, -0.3472,  0.1644],
        [ 0.1394,  0.4685, -0.2373,  ...,  0.7661,  0.1207,  0.1090],
        [ 0.4150,  0.1025, -0.4175,  ...,  0.5044, -0.1226, -0.0782],
        [ 0.5171,  0.0276, -0.3269,  ...,  1.0449,  0.0609,  0.1793]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.4856, -0.2050, -0.5176,  ...,  1.1436,  0.1714,  0.3682],
        [-0.2471, -0.1190, -0.3103,  ...,  0.7139,  0.1769, -0.0166],
        [-0.0461, -0.0434, -0.2283,  ...,  0.3740, -0.2390, -0.0433],
        [ 0.3774,  0.2112, -0.3198,  ...,  1.0850, -0.0922,  0.2903]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.0572,  0.0179, -0.2125,  ...,  0.6162, -0.4526,  0.1542],
        [ 0.4709,  0.0795, -0.3647,  ...,  0.6523,  0.3394,  0.0293],
        [ 0.5151,  0.2346, -0.2715,  ...,  0.7505,  0.1272,  0.0462],
        [ 0.0367,  0.3213, -0.4941,  ...,  0.8779, -0.1816, -0.0253]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.2727,  0.3303, -0.5044,  ...,  0.7168,  0.0629,  0.1035],
        [-0.0400, -0.1553, -0.1060,  ...,  0.5640, -0.2159,  0.2418],
        [ 0.1556, -0.0235, -0.4993,  ...,  0.7207,  0.0313, -0.1196],
        [ 0.2317,  0.0298, -0.8999,  ...,  0.9937, -0.0415,  0.2147]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.0630,  0.0980, -0.3701,  ...,  0.7363,  0.2156, -0.2639],
        [ 0.0561,  0.1478, -0.3899,  ...,  1.0488, -0.2688, -0.0457],
        [ 0.0971,  0.2556, -0.4043,  ...,  0.7764,  0.0069,  0.0930],
        [ 0.0486,  0.1345, -0.3936,  ...,  0.7930, -0.0837, -0.0250]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.1108, -0.1715, -0.0759,  ...,  0.8550, -0.0698, -0.0229],
        [-0.0903,  0.0344, -0.3806,  ...,  1.0039, -0.2346,  0.0242],
        [ 0.1354, -0.0599, -0.3384,  ...,  0.7446,  0.3757,  0.5601],
        [ 0.2091,  0.2673, -0.3215,  ...,  0.3235, -0.0099, -0.1584]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Loss after 4730 Batch is nan 
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.0438, -0.1412, -0.1659,  ...,  0.3193,  0.0234,  0.2252],
        [ 0.2710,  0.1661, -0.2251,  ...,  0.9043, -0.0767, -0.0189],
        [ 0.1156, -0.2954,  0.2688,  ...,  0.4290,  0.2876,  0.2688],
        [ 0.4712,  0.4385, -0.4746,  ...,  0.8579, -0.0100,  0.0160]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBa

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[-0.0369,  0.0842, -0.6646,  ...,  0.6279,  0.0502,  0.0652],
        [ 0.2690,  0.0703, -0.3389,  ...,  1.0938,  0.0181, -0.0390],
        [ 0.0340, -0.1945, -0.4053,  ...,  0.9409,  0.0040,  0.2764],
        [ 0.0253, -0.2671, -0.5952,  ...,  0.8145, -0.3096,  0.0573]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.0255,  0.0049, -0.2632,  ...,  0.9873, -0.0687,  0.1154],
        [ 0.1193, -0.2343, -0.1985,  ...,  0.2976, -0.2330,  0.4431],
        [-0.2666, -0.0685, -0.2856,  ...,  0.6182, -0.2333,  0.0421],
        [ 0.1247,  0.0458, -0.1520,  ...,  0.7891, -0.1530,  0.0690]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[-0.4224,

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.3984, -0.0843, -0.4617,  ...,  0.4585, -0.0689,  0.5708],
        [ 0.2563,  0.1323, -0.2175,  ...,  0.6211, -0.0536, -0.1718],
        [-0.1197,  0.1074, -0.3665,  ...,  0.7612, -0.2866, -0.0558],
        [ 0.3794, -0.0746, -0.3225,  ...,  0.6021, -0.0012,  0.2367]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.2139, -0.2568, -0.1439,  ...,  0.4941, -0.1192,  0.0908],
        [-0.0890, -0.1882, -0.2622,  ...,  0.6392, -0.1747, -0.1553],
        [ 0.2305,  0.2457, -0.2045,  ...,  1.0244,  0.0220, -0.0450],
        [ 0.2101,  0.3579,  0.1949,  ...,  0.7607,  0.2573, -0.1505]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.0208,  0.2039, -0.4983,  ...,  1.1035, -0.1589,  0.2568],
        [ 0.4080, -0.2450, -0.2158,  ...,  0.5703, -0.0276,  0.2399],
        [-0.0258,  0.2649, -0.0972,  ...,  0.8071, -0.2593, -0.0127],
        [ 0.1169,  0.1184, -0.4233,  ...,  0.5400, -0.1538,  0.2615]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

outputs_teacher tensor([[ 0.0126,  0.1329, -0.1923,  ...,  0.9058, -0.0284,  0.1102],
        [ 0.1482,  0.2993, -0.3901,  ...,  1.0283, -0.0831, -0.0575],
        [ 0.1129,  0.2585, -0.4604,  ...,  0.4885,  0.3469,  0.1537],
        [ 0.3408, -0.1414, -0.0220,  ...,  0.6250,  0.1014,  0.0273]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.2888, -0.4905, -0.2771,  ...,  0.6899,  0.1367,  0.3210],
        [ 0.2012,  0.1544, -0.1625,  ...,  0.5742, -0.0610,  0.0283],
        [ 0.3374,  0.1677, -0.3308,  ...,  0.7544, -0.0769, -0.0364],
        [ 0.4968,  0.1140, -0.5396,  ...,  0.7788,  0.0389,  0.1497]],
     

outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.4924,  0.0041, -0.3716,  ...,  0.9448,  0.0714,  0.0833],
        [ 0.3613,  0.2238, -0.6382,  ...,  0.5840,  0.2061,  0.3713],
        [ 0.4131, -0.0949, -0.1125,  ...,  0.5942,  0.0456,  0.0403],
        [ 0.3911,  0.5444, -0.3462,  ...,  0.4255,  0.0150,  0.0497]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.4180,

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.5112,  0.2457, -0.3489,  ...,  0.7012,  0.0044,  0.1237],
        [ 0.1711,  0.3645, -0.3391,  ...,  1.0020, -0.1444,  0.1146],
        [ 0.2257, -0.1349, -0.2023,  ...,  0.6206,  0.1741, -0.2366],
        [ 0.1901,  0.1384, -0.2942,  ...,  0.5386,  0.0456,  0.4177]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Loss after 4830 Batch is nan 
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBa

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.3865,  0.4248, -0.1250,  ...,  0.6313,  0.2009, -0.1152],
        [ 0.2771, -0.1390, -0.3574,  ...,  0.3630, -0.3752,  0.2751],
        [ 0.1721,  0.4553, -0.1616,  ...,  0.9775, -0.1334,  0.3154],
        [-0.0088,  0.1787, -0.1327,  ...,  0.6934,  0.0190,  0.1014]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.1255,  0.0111, -0.2808,  ...,  0.7871,  0.0061,  0.1410],
        [ 0.0247, -0.1620, -0.1653,  ...,  0.6572, -0.0327,  0.1271],
        [-0.2715,  0.1644, -0.2546,  ...,  0.9243, -0.1032, -0.0442],
        [ 0.4937,  0.1620, -0.4038,  ...,  0.4519, -0.0886, -0.0909]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.1436,  0.0174, -0.4880,  ...,  1.1406, -0.1567, -0.0944],
        [ 0.2771, -0.0492, -0.4617,  ...,  0.6592, -0.4622,  0.0043],
        [ 0.3711,  0.1750, -0.3035,  ...,  0.7769, -0.2357, -0.2771],
        [-0.2517,  0.2040, -0.4438,  ...,  0.6792, -0.2893,  0.0952]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

outputs_teacher tensor([[ 0.1810,  0.3389, -0.2507,  ...,  0.7827,  0.1479, -0.2041],
        [ 0.0864, -0.2788, -0.2363,  ...,  0.7700, -0.3701,  0.2891],
        [-0.1593,  0.0847, -0.2206,  ...,  0.7065, -0.1126,  0.1559],
        [-0.0300, -0.1584, -0.2255,  ...,  0.8271,  0.2314, -0.0811]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.2632,  0.0317, -0.3379,  ...,  0.5332,  0.3040,  0.3638],
        [ 0.1764,  0.2024, -0.4292,  ...,  0.5005, -0.0327, -0.1707],
        [ 0.2094,  0.0198, -0.0661,  ...,  0.5581, -0.0837,  0.0198],
        [-0.0641,  0.1860, -0.3740,  ...,  0.3384,  0.1059,  0.1245]],
     

outputs_teacher tensor([[ 0.2317,  0.2852, -0.2368,  ...,  0.8999,  0.0966, -0.0188],
        [ 0.2344,  0.1628, -0.4451,  ...,  0.3760, -0.2477, -0.0601],
        [ 0.2800,  0.5981, -0.2119,  ...,  0.8037, -0.3130, -0.1043],
        [ 0.2155,  0.3845, -0.3354,  ...,  0.5464, -0.1033, -0.0785]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.0548,  0.3379, -0.1841,  ...,  0.9224,  0.0801, -0.2424],
        [ 0.5566,  0.1263, -0.3601,  ...,  0.5815, -0.1547,  0.1351],
        [ 0.1333,  0.1364, -0.4934,  ...,  0.8467, -0.1985,  0.1807],
        [ 0.3792, -0.1327, -0.3887,  ...,  0.9087, -0.2825, -0.1543]],
     

outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.1254,  0.0899, -0.2583,  ...,  0.7305,  0.3362,  0.1823],
        [ 0.2861,  0.1176, -0.0623,  ...,  0.5225, -0.0444, -0.1410],
        [ 0.1816,  0.2632,  0.0457,  ...,  0.3821,  0.1937,  0.2847],
        [ 0.1543,  0.2661, -0.4814,  ...,  0.7485, -0.0586, -0.0846]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.0595,

outputs_teacher tensor([[ 0.3315,  0.2637, -0.4451,  ...,  0.9478,  0.1190,  0.1545],
        [ 0.1509,  0.1804, -0.1587,  ...,  0.8057,  0.0839,  0.1973],
        [ 0.1543,  0.5439, -0.4417,  ...,  0.6221,  0.0326, -0.0565],
        [ 0.3848,  0.2372, -0.2209,  ...,  1.0342, -0.0948,  0.1710]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.1208, -0.0011, -0.3345,  ...,  0.4890,  0.0606,  0.4312],
        [ 0.3977, -0.3547, -0.0868,  ...,  0.7563,  0.3708,  0.3171],
        [ 0.2136, -0.1992, -0.3962,  ...,  0.4761, -0.4072,  0.4099],
        [ 0.6211, -0.0190, -0.4426,  ...,  0.6924,  0.0078,  0.2109]],
     

outputs_teacher tensor([[ 0.1729,  0.4214, -0.1396,  ...,  0.3528, -0.2271,  0.3977],
        [ 0.1379,  0.2415, -0.1270,  ...,  0.6387, -0.2808,  0.2273],
        [ 0.0109,  0.0873, -0.2954,  ...,  0.5396, -0.2893, -0.0284],
        [ 0.3633,  0.1520, -0.6362,  ...,  0.8750,  0.2693,  0.2286]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.0314, -0.3276, -0.2944,  ...,  0.6865, -0.1332,  0.4077],
        [ 0.0821, -0.1276, -0.2593,  ...,  0.8325,  0.0986,  0.2622],
        [ 0.3042,  0.2357, -0.3201,  ...,  0.4194,  0.2698, -0.1133],
        [ 0.4631,  0.2251, -0.3035,  ...,  0.5176,  0.0836, -0.0540]],
     

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.0651,  0.5776, -0.4404,  ...,  0.5586,  0.2800,  0.0451],
        [ 0.1771, -0.0479, -0.4185,  ...,  0.4824, -0.2340,  0.2217],
        [ 0.2515, -0.0500, -0.3508,  ...,  0.6221, -0.0314,  0.4785],
        [ 0.2166,  0.0205, -0.5225,  ...,  0.5405, -0.0720,  0.1797]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.1053,  0.2466, -0.3198,  ...,  0.8994, -0.0936,  0.1299],
        [ 0.3440,  0.1766, -0.3545,  ...,  0.5439,  0.0629, -0.0134],
        [-0.2284, -0.1273, -0.3494,  ...,  0.8198, -0.1365,  0.0224],
        [ 0.1238,  0.6411, -0.1837,  ...,  0.8604,  0.1075, -0.1187]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[-0.2078, -0.2354, -0.1526,  ...,  1.0693, -0.2830, -0.0455],
        [ 0.2908,  0.0305, -0.4526,  ...,  0.6895, -0.3560,  0.2267],
        [ 0.1594,  0.0849, -0.1493,  ...,  0.5356,  0.0598,  0.1326],
        [ 0.2053,  0.3457, -0.4717,  ...,  1.1289, -0.0071,  0.0595]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.1072,  0.4929, -0.4165,  ...,  0.6426, -0.0456,  0.3269],
        [ 0.4304,  0.2377, -0.3540,  ...,  0.8599,  0.2351,  0.2203],
        [ 0.4028, -0.4734, -0.3628,  ...,  0.4043, -0.2988,  0.2417],
        [ 0.5186,  0.3132, -0.3013,  ...,  0.7710, -0.0977,  0.2856]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.1819,  0.2388, -0.1869,  ...,  0.6396,  0.2285, -0.0996],
        [ 0.2732, -0.1188,  0.0015,  ...,  0.5840,  0.0764, -0.1810],
        [ 0.1631, -0.1766, -0.2300,  ...,  0.6777, -0.3279,  0.2642],
        [ 0.2074, -0.0523, -0.3120,  ...,  0.5874, -0.3213,  0.2380]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.3784,  0.0138, -0.4250,  ...,  1.0752, -0.0866,  0.2905],
        [ 0.1335, -0.3369, -0.4272,  ...,  0.5557, -0.1704,  0.1367],
        [ 0.3884,  0.4065, -0.1781,  ...,  0.7812,  0.0529,  0.0068],
        [ 0.1533,  0.0981, -0.1826,  ...,  0.7827, -0.3518,  0.2927]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.1548,  0.3684, -0.3118,  ...,  0.7310,  0.2517, -0.3208],
        [ 0.2015,  0.5176, -0.1993,  ...,  0.5713, -0.0981, -0.2067],
        [ 0.3860,  0.1913, -0.3157,  ...,  0.8086, -0.1732, -0.2379],
        [ 0.3472, -0.1838, -0.0786,  ...,  0.6987,  0.2426,  0.3167]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[-0.2003, -0.3057, -0.2147,  ...,  0.9067, -0.1292,  0.0344],
        [ 0.1814, -0.0016, -0.2839,  ...,  1.1865,  0.0363,  0.1604],
        [-0.0215,  0.3120, -0.4912,  ...,  0.8022, -0.0632,  0.2421],
        [ 0.4082, -0.0679, -0.1329,  ...,  0.8584,  0.1044,  0.2532]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[-0.1847, -0.0346, -0.2510,  ...,  0.6113, -0.2307, -0.1403],
        [-0.0048,  0.5962, -0.3345,  ...,  0.8032, -0.3430, -0.0635],
        [ 0.2786, -0.0334, -0.5059,  ...,  0.7388, -0.1594,  0.4614],
        [ 0.0472, -0.0717, -0.0790,  ...,  0.8545, -0.2849,  0.1071]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.1769,  0.2720, -0.3555,  ...,  0.7036, -0.1697,  0.2356],
        [-0.1716,  0.0368, -0.4907,  ...,  0.9233, -0.0819,  0.2798],
        [ 0.0727,  0.0061, -0.4841,  ...,  0.9702, -0.2581,  0.1310],
        [ 0.5225,  0.0419, -0.0790,  ...,  0.4900, -0.1620, -0.1278]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.1840,  0.2172, -0.2930,  ...,  0.7271, -0.3250,  0.1741],
        [ 0.3333,  0.6768, -0.4277,  ...,  0.5327,  0.1081,  0.7515],
        [ 0.5166,  0.0091, -0.6411,  ...,  0.7119,  0.0238,  0.2141],
        [ 0.2942, -0.0162, -0.2842,  ...,  0.8184, -0.0336, -0.0546]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[-0.0143,  0.0774, -0.1326,  ...,  0.5894, -0.0845,  0.0091],
        [ 0.0587,  0.1580, -0.0530,  ...,  0.6294,  0.1879,  0.0409],
        [ 0.1705, -0.1566, -0.0568,  ...,  0.4954, -0.1818, -0.0151],
        [ 0.4517,  0.2642, -0.1120,  ...,  0.3989, -0.0911,  0.1545]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.3372, -0.0868, -0.1476,  ...,  0.7163,  0.1165, -0.2217],
        [ 0.0518,  0.1399, -0.1204,  ...,  0.8994, -0.1646,  0.0569],
        [-0.2717, -0.0379, -0.1609,  ...,  0.6987, -0.2280,  0.0756],
        [ 0.3347,  0.1982, -0.2120,  ...,  0.4753,  0.2456, -0.1261]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.4595, -0.2175, -0.0577,  ...,  0.4321,  0.3726,  0.0940],
        [ 0.3181,  0.1976, -0.4573,  ...,  0.3157, -0.1317,  0.4021],
        [ 0.1389,  0.1301, -0.4331,  ...,  0.8389,  0.1295, -0.1854],
        [ 0.3308,  0.3372, -0.2834,  ...,  0.6875, -0.1772, -0.2676]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[-0.1387,  0.1032, -0.2487,  ...,  1.1982, -0.2483,  0.0997],
        [ 0.2534,  0.0671, -0.6797,  ..., -0.3894,  0.3157,  0.3618],
        [ 0.0168,  0.1957, -0.1782,  ...,  0.9702, -0.0206,  0.2068],
        [ 0.3093,  0.2335, -0.2629,  ...,  1.0371, -0.0963, -0.2352]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.4846, -0.5137, -0.1370,  ...,  0.6602,  0.0863,  0.3232],
        [ 0.1729,  0.3149, -0.7700,  ...,  0.8125, -0.0699, -0.0027],
        [ 0.4785,  0.2532, -0.4299,  ...,  0.7437, -0.0978,  0.0480],
        [ 0.0528,  0.4116, -0.2311,  ...,  0.5347, -0.0151,  0.1434]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.5938,  0.4036, -0.2585,  ...,  0.6548, -0.1515, -0.1477],
        [ 0.0366, -0.0955, -0.2369,  ...,  0.5347, -0.1831,  0.4480],
        [ 0.2490,  0.1637, -0.2881,  ...,  0.9326, -0.1431,  0.0609],
        [ 0.1719,  0.3403, -0.4375,  ...,  0.8374, -0.0311,  0.1761]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.1641, -0.4570, -0.3313,  ...,  1.0947, -0.3159, -0.2002],
        [ 0.1654, -0.0912, -0.2734,  ...,  0.9185, -0.1515,  0.1901],
        [ 0.3862, -0.0509, -0.1322,  ...,  0.2766,  0.0858,  0.3350],
        [ 0.0831, -0.1902, -0.3796,  ...,  0.4487, -0.1042,  0.2180]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[-0.1580,  0.2208,  0.0665,  ...,  0.7017, -0.2034,  0.0270],
        [ 0.3799,  0.5454, -0.5879,  ...,  0.4983,  0.2225,  0.3647],
        [ 0.3623, -0.0462, -0.2578,  ...,  0.3899, -0.2268,  0.2046],
        [ 0.2250,  0.2421, -0.4558,  ...,  0.7021, -0.1177, -0.0378]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.2313,  0.1228, -0.4062,  ...,  0.7173,  0.0448,  0.0449],
        [ 0.1750,  0.2347, -0.4250,  ...,  0.7456, -0.2727,  0.2189],
        [ 0.1836,  0.2395, -0.3250,  ...,  0.8486, -0.1093,  0.2061],
        [ 0.3818,  0.0670, -0.1809,  ...,  0.6851,  0.3374,  0.0651]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.1383,  0.0503, -0.3687,  ...,  0.5210, -0.1614, -0.1477],
        [ 0.6216,  0.1903, -0.1460,  ...,  0.6113, -0.0015,  0.1830],
        [ 0.0305,  0.2375, -0.4814,  ...,  1.0244, -0.0125,  0.1843],
        [ 0.2472,  0.2042, -0.3447,  ...,  0.6318, -0.1271,  0.2466]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[-0.0517,  0.0678, -0.3538,  ...,  0.4487, -0.2639, -0.3206],
        [ 0.3743,  0.1166, -0.1708,  ...,  0.7656, -0.1606, -0.2007],
        [ 0.1750, -0.0925, -0.3979,  ...,  0.7080, -0.2267,  0.3435],
        [ 0.2260,  0.1095, -0.4109,  ...,  0.3193,  0.1614,  0.2281]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.5615,  0.2202, -0.6743,  ...,  0.4792,  0.1919,  0.3369],
        [ 0.5732,  0.0996, -0.4021,  ...,  0.7998, -0.1497, -0.0175],
        [ 0.1736,  0.3008, -0.3562,  ...,  1.0957, -0.0606,  0.0259],
        [ 0.6948,  0.2920, -0.3850,  ...,  0.6743,  0.0082,  0.4204]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[-0.3062,  0.3281, -0.1729,  ...,  0.5117, -0.1157,  0.3247],
        [ 0.1981, -0.1133, -0.4543,  ...,  0.6323, -0.2274,  0.0526],
        [ 0.1842,  0.3167, -0.4312,  ...,  0.5708,  0.3376, -0.1065],
        [ 0.2808,  0.6782, -0.3030,  ...,  0.2915,  0.0385,  0.5630]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.7832,  0.1669, -0.0706,  ...,  0.7578,  0.1609,  0.0902],
        [ 0.2428,  0.3921, -0.4419,  ..., -0.2097,  0.0262,  0.1534],
        [ 0.0815, -0.1759, -0.3547,  ...,  0.2085,  0.1084,  0.0306],
        [ 0.1914, -0.2102,  0.0140,  ...,  0.5088, -0.1609,  0.2207]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[-0.0870,  0.1653, -0.1996,  ...,  0.9868, -0.1902,  0.1327],
        [-0.0533,  0.2380, -0.2410,  ...,  0.3369, -0.0337,  0.0014],
        [ 0.4150,  0.0754, -0.3984,  ...,  0.6763,  0.0735,  0.2272],
        [ 0.2917, -0.1242, -0.6060,  ...,  0.3899, -0.0605,  0.6084]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.0768,  0.2423, -0.0467,  ...,  0.8369,  0.1047,  0.1675],
        [ 0.1980,  0.1019, -0.3215,  ...,  0.4165,  0.5366, -0.1142],
        [ 0.2141,  0.1173, -0.4114,  ...,  0.9214, -0.0174,  0.2251],
        [ 0.1516, -0.0318, -0.1798,  ...,  0.2988, -0.0284,  0.1606]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.0878,  0.0802, -0.3105,  ...,  0.7129, -0.1029,  0.1102],
        [ 0.1129, -0.0066, -0.1010,  ...,  0.6118, -0.1428,  0.1674],
        [-0.0703,  0.1742, -0.4446,  ...,  0.7271, -0.1926, -0.2043],
        [ 0.1986,  0.0612, -0.3210,  ...,  0.8564, -0.1866,  0.1136]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.4451,  0.3064, -0.3101,  ...,  0.6646, -0.1405,  0.3459],
        [ 0.1930,  0.0347, -0.3660,  ...,  1.0410, -0.0680,  0.0735],
        [ 0.0487, -0.0136, -0.1604,  ...,  0.7588, -0.0556,  0.0249],
        [ 0.0461,  0.1804, -0.1743,  ...,  1.0020, -0.0835,  0.1117]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.0082, -0.0872, -0.0334,  ...,  0.6841,  0.0138, -0.2876],
        [ 0.1835,  0.3433, -0.5488,  ...,  0.7554, -0.0541, -0.0858],
        [ 0.3013,  0.2756, -0.3594,  ...,  0.5991,  0.1246,  0.1062],
        [ 0.2588,  0.2996, -0.1227,  ...,  0.7949, -0.0903,  0.1499]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.3215,  0.1404, -0.3301,  ...,  0.6040,  0.0156, -0.3015],
        [ 0.5898,  0.1475,  0.0755,  ...,  0.3191,  0.0457,  0.0828],
        [-0.0877,  0.0948, -0.2771,  ...,  0.3391, -0.0906, -0.1198],
        [ 0.1923,  0.1517, -0.1798,  ...,  0.7676,  0.0057, -0.1487]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 6.9946e-02, -4.2511e-02, -2.1881e-02,  ...,  6.5723e-01,
         -4.5410e-01,  1.5845e-01],
        [-2.8052e-01,  6.3181e-06,  2.7908e-02,  ...,  7.0020e-01,
         -2.3547e-01, -9.9792e-02],
        [ 2.4146e-01, -3.8330e-02, -1.4099e-01,  ...,  8.5498e-01,
          6.6956e-02,  1.6626e-01],
        [ 2.0886e-01, -8.9233e-02, -3.1177e-01,  ...,  6.0742e-01,
          1.1200e-01,  3.5229e-01]], device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [n

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.1595,  0.1516, -0.4026,  ...,  0.9224,  0.0249, -0.0696],
        [ 0.4714,  0.4558, -0.6250,  ...,  0.7173, -0.3718,  0.0278],
        [ 0.4736, -0.1143, -0.1637,  ...,  0.9023, -0.0468, -0.3577],
        [-0.1777,  0.1842, -0.3450,  ...,  1.0234, -0.2125,  0.0178]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.2281,  0.1136, -0.3477,  ...,  0.7900, -0.0081, -0.2769],
        [ 0.7153, -0.1141,  0.0695,  ...,  0.2303, -0.3799,  0.2734],
        [-0.1111,  0.1615, -0.0616,  ...,  0.7661, -0.0584, -0.0610],
        [ 0.2494,  0.0012, -0.3186,  ...,  0.7114, -0.1951,  0.2837]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.3469,  0.2141, -0.2180,  ...,  0.6870, -0.0753,  0.1735],
        [ 0.2017,  0.2869, -0.4307,  ...,  1.0908,  0.1949,  0.5796],
        [ 0.3154,  0.0421, -0.1168,  ...,  0.6646, -0.2627, -0.1570],
        [ 0.4512,  0.3660, -0.8066,  ...,  0.5913, -0.1011,  0.2194]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[-0.2460,  0.0135, -0.0999,  ...,  0.7471, -0.1180,  0.2168],
        [-0.0918,  0.3152, -0.2007,  ...,  0.5986, -0.0865,  0.0098],
        [ 0.3328,  0.0296, -0.3091,  ...,  0.5601, -0.0329, -0.0052],
        [ 0.4775,  0.0588, -0.4844,  ...,  0.5898,  0.2515,  0.2247]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.4158,  0.1715, -0.1699,  ...,  0.6948,  0.0742,  0.2102],
        [-0.1880, -0.1022, -0.4509,  ...,  1.0586,  0.1447,  0.0140],
        [ 0.1310,  0.2360, -0.2844,  ...,  0.7129,  0.0537,  0.2793],
        [ 0.0765,  0.1202, -0.0017,  ...,  0.6772, -0.0839,  0.4937]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.6260,

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.7490,  0.3645, -0.2211,  ...,  0.8735, -0.0128,  0.2394],
        [ 0.0923,  0.0064, -0.1191,  ...,  0.8115, -0.0760,  0.0359],
        [ 0.3464,  0.1525,  0.1174,  ...,  0.6440, -0.0122, -0.1161],
        [ 0.0224,  0.3269,  0.0106,  ...,  0.8428,  0.0353, -0.0213]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.2108, -0.0358, -0.2318,  ...,  0.4287,  0.1897,  0.2834],
        [ 0.2815, -0.0514, -0.2654,  ...,  0.5957,  0.1114,  0.0591],
        [-0.0994,  0.0245, -0.3782,  ...,  1.1035, -0.0948,  0.0724],
        [ 0.2505, -0.0905, -0.7852,  ...,  0.5845,  0.2485,  0.3728]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.3110,  0.0688, -0.4194,  ...,  0.7627,  0.0121,  0.0637],
        [ 0.2749,  0.0923, -0.5835,  ...,  0.7485,  0.0163, -0.1912],
        [ 0.3206,  0.2737, -0.0139,  ...,  0.6807, -0.0894,  0.2119],
        [-0.0073,  0.2042, -0.1349,  ...,  0.7280, -0.0421,  0.1931]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.3049, -0.2207, -0.1211,  ...,  0.4243, -0.2551,  0.4004],
        [ 0.0406,  0.1028, -0.5889,  ...,  0.4189,  0.1202,  0.0745],
        [ 0.1818,  0.1395, -0.2203,  ...,  0.5820, -0.1482, -0.0704],
        [ 0.4316,  0.2043, -0.2451,  ...,  0.6118, -0.0368, -0.1719]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.0379,  0.3325, -0.4915,  ...,  0.4973, -0.0016,  0.1621],
        [ 0.4163,  0.1472, -0.4634,  ...,  0.6455, -0.0322,  0.0017],
        [ 0.2588, -0.0441, -0.4253,  ...,  0.5786, -0.1544,  0.0330],
        [ 0.0631, -0.1510, -0.5669,  ...,  0.6553,  0.2878,  0.1611]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.4053, -0.3618, -0.1984,  ...,  0.7104,  0.2495,  0.0497],
        [ 0.7021,  0.2537, -0.2932,  ...,  0.8320,  0.0822,  0.5063],
        [ 0.4978,  0.1268, -0.6201,  ...,  0.6826,  0.1199, -0.0725],
        [ 0.3186, -0.0604, -0.3848,  ...,  0.9766, -0.1831,  0.0070]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.2345, -0.1370, -0.6187,  ...,  0.6411,  0.1733,  0.2800],
        [ 0.5205,  0.0558, -0.0934,  ...,  0.7925,  0.4131, -0.0543],
        [ 0.0945,  0.2944, -0.4617,  ...,  0.8394,  0.2612,  0.3035],
        [ 0.0771,  0.1141, -0.3059,  ...,  0.4077, -0.1342,  0.2505]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

outputs_teacher tensor([[ 0.0840, -0.0925, -0.4043,  ...,  0.9468,  0.1027,  0.0478],
        [-0.1010,  0.0734, -0.5493,  ...,  0.6978, -0.2448,  0.2499],
        [ 0.4180,  0.1091, -0.3682,  ...,  0.9175,  0.2322, -0.0098],
        [ 0.2798, -0.0596, -0.5229,  ...,  0.5850, -0.1026,  0.1440]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.3911,  0.1072, -0.4795,  ...,  0.5410,  0.0035, -0.2106],
        [-0.0264, -0.2445, -0.3906,  ...,  0.7148,  0.0149,  0.1720],
        [ 0.0663,  0.2622, -0.0695,  ...,  0.9292, -0.0760, -0.0195],
        [ 0.4307,  0.2217, -0.4521,  ...,  0.8877, -0.1204, -0.2106]],
     

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.0013,  0.3389, -0.3965,  ...,  0.8198, -0.0806, -0.1516],
        [ 0.2324,  0.0256, -0.2764,  ...,  0.5537, -0.2213,  0.3562],
        [-0.1267, -0.3240, -0.4202,  ...,  0.4226, -0.2642,  0.3079],
        [ 0.0942, -0.2561, -0.0792,  ...,  1.1660, -0.0261,  0.1644]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.4126,  0.2028, -0.4607,  ...,  0.6099, -0.0739, -0.0681],
        [ 0.0951,  0.2047, -0.6587,  ...,  0.7021,  0.1436,  0.3298],
        [ 0.0898, -0.1100, -0.4119,  ...,  0.7598, -0.1544,  0.0919],
        [ 0.4124,  0.2433, -0.3081,  ...,  0.7363,  0.3540,  0.0848]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.7026,  0.0452,  0.0732,  ...,  0.6221,  0.0600, -0.0097],
        [ 0.1823,  0.2539, -0.1809,  ...,  0.8076, -0.3921,  0.2812],
        [ 0.1566, -0.2034, -0.3757,  ...,  0.5063, -0.2134,  0.0358],
        [ 0.2275,  0.1584, -0.2480,  ...,  0.9087,  0.2578, -0.0792]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[-0.0684,  0.1296, -0.4092,  ...,  0.5518,  0.2908,  0.3064],
        [ 0.2335, -0.0746, -0.1906,  ...,  0.3784, -0.0172,  0.1042],
        [ 0.3489,  0.0754, -0.4258,  ...,  0.8486, -0.0336, -0.1099],
        [ 0.4343,  0.0159, -0.2222,  ...,  0.7681,  0.0859, -0.1622]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.1766,  0.4656, -0.3376,  ...,  0.6357, -0.0205,  0.4324],
        [-0.0334,  0.2172, -0.0037,  ...,  0.5967,  0.0626, -0.1016],
        [ 0.2208,  0.1329, -0.5947,  ...,  1.1465,  0.0202, -0.1239],
        [ 0.2639,  0.1803, -0.0441,  ...,  0.3555, -0.3416, -0.0630]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.5259,  0.0940, -0.2159,  ...,  0.7002,  0.2739,  0.0283],
        [-0.0482,  0.1832, -0.3352,  ...,  0.8262, -0.1741, -0.1044],
        [-0.0779, -0.1649, -0.4114,  ...,  0.5669, -0.0332,  0.2825],
        [-0.0538,  0.0442,  0.0273,  ...,  0.8008, -0.1140, -0.0019]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.0761,  0.1836, -0.3992,  ...,  0.7236,  0.0159,  0.1969],
        [ 0.4478,  0.3765, -0.2195,  ...,  0.6890, -0.2590,  0.1669],
        [ 0.3982,  0.0246, -0.5205,  ...,  0.5850, -0.1114,  0.2478],
        [ 0.1500,  0.4321, -0.2917,  ...,  0.6763,  0.2181, -0.0743]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.4229,  0.3599, -0.3767,  ...,  0.7905, -0.0720,  0.0391],
        [-0.0773,  0.2437, -0.4578,  ...,  0.8916, -0.1575, -0.1432],
        [ 0.0480,  0.3752, -0.4087,  ...,  0.6831, -0.4158,  0.4258],
        [ 0.2030,  0.3928, -0.6416,  ...,  0.9639, -0.0300,  0.0729]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.0967,  0.0099, -0.7158,  ...,  0.4297, -0.0619,  0.5898],
        [ 0.0478,  0.3328, -0.5347,  ...,  0.9683, -0.0358,  0.2361],
        [ 0.7432, -0.0147, -0.3816,  ...,  0.7998, -0.2140,  0.2058],
        [ 0.5430,  0.2261, -0.4023,  ...,  0.6826, -0.2295,  0.0074]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.3650, -0.0115, -0.2424,  ...,  0.7686, -0.2117, -0.0032],
        [ 0.0882,  0.0256,  0.0492,  ...,  0.9619,  0.0641, -0.1077],
        [ 0.1671, -0.1530,  0.0150,  ...,  0.4514,  0.2325,  0.2800],
        [ 0.2830,  0.0134, -0.5044,  ...,  0.6982,  0.1094, -0.1254]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.2240,  0.2045, -0.3250,  ...,  0.5063, -0.0234,  0.2573],
        [ 0.1992,  0.3921, -0.2949,  ...,  0.9326,  0.2435,  0.1851],
        [-0.1343,  0.1833, -0.0640,  ...,  1.1514, -0.0521,  0.0323],
        [ 0.2456, -0.2227, -0.4768,  ...,  0.7744, -0.0710, -0.3669]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[-6.8545e-06,  2.2095e-01, -2.6807e-01,  ...,  6.3623e-01,
         -7.4707e-02, -3.4485e-02],
        [ 2.2205e-01, -1.7920e-01, -3.4473e-01,  ...,  8.2275e-01,
         -8.6487e-02,  3.1519e-01],
        [ 8.3313e-03,  8.9340e-03, -3.0762e-01,  ...,  8.9502e-01,
         -2.2705e-01,  1.2842e-01],
        [ 6.9336e-02,  2.5439e-01, -6.2744e-02,  ...,  9.0967e-01,
         -1.0388e-01,  5.2094e-02]], device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [n

outputs_teacher tensor([[ 0.2676,  0.0632, -0.4211,  ...,  0.4158, -0.0190, -0.5356],
        [-0.0588,  0.2068, -0.1117,  ...,  0.8672,  0.1777, -0.2690],
        [ 0.6987, -0.0478, -0.1503,  ...,  0.7544, -0.0061,  0.4653],
        [ 0.0735,  0.2954, -0.2292,  ...,  0.6587, -0.1368,  0.0351]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Loss after 5630 Batch is nan 
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.0555, -0.2147, -0.2244,  ...,  0.8037, -0.0647, -0.1593],
        [-0.0674,  0.0999, -0.3540,  ...,  1.0068, -0.0056, -0.4749],
        [ 0.6255,  0.1547, -0.5371,  ...,  0.4258,  0.1475,  0.0486],
        [ 0.3711,  0.2048, -0.3140,  ...,  0.5

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 2.2363e-01, -4.0649e-02, -3.2349e-01,  ...,  6.9727e-01,
          2.7466e-01,  3.5828e-02],
        [-1.0889e-01,  7.7963e-04, -5.0146e-01,  ...,  9.8340e-01,
          2.5439e-01, -1.3000e-01],
        [ 1.6434e-02,  1.5051e-01, -2.5659e-01,  ...,  8.7451e-01,
          1.3794e-02,  2.5659e-01],
        [ 2.3303e-01,  9.3445e-02, -1.1853e-01,  ...,  7.3389e-01,
          6.8787e-02, -7.1831e-03]], device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [n

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[-0.1085, -0.0576, -0.0844,  ...,  0.7158, -0.4426, -0.0296],
        [ 0.0803,  0.3342, -0.4229,  ...,  1.0615, -0.0716, -0.0668],
        [ 0.4211,  0.2346, -0.1860,  ...,  0.4062, -0.2712,  0.6279],
        [ 0.2451,  0.3982, -0.2581,  ...,  0.7070,  0.2113, -0.0863]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.2371,  0.3984, -0.2639,  ...,  0.4360,  0.1566, -0.1860],
        [ 0.4482,  0.1132, -0.1469,  ...,  0.9595,  0.0154,  0.1875],
        [ 0.0976, -0.1202, -0.4468,  ...,  0.4785,  0.0295,  0.0564],
        [ 0.2461,  0.1348, -0.3108,  ...,  1.0557,  0.2715,  0.4387]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 2.9150e

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.1057, -0.1153, -0.4912,  ...,  0.6841, -0.3232,  0.0411],
        [ 0.2959,  0.1753, -0.1569,  ...,  0.7524, -0.0757, -0.1224],
        [ 0.1960, -0.0339, -0.3577,  ...,  0.5293,  0.1420,  0.1787],
        [ 0.6558, -0.2510, -0.6895,  ...,  0.1247, -0.4080,  0.5259]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[-0.0038, -0.1372, -0.2426,  ...,  0.6816, -0.3484, -0.0391],
        [-0.0046,  0.3108, -0.0617,  ...,  0.9614,  0.2544,  0.3242],
        [ 0.2196,  0.2664, -0.4885,  ...,  0.7222, -0.0453,  0.1260],
        [ 0.1919,  0.2412, -0.5420,  ...,  0.8628,  0.0677,  0.0207]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.4749, -0.3525, -0.4778,  ...,  0.7319,  0.2419,  0.2917],
        [ 0.1390,  0.1492, -0.4243,  ...,  0.9756,  0.2001, -0.0932],
        [ 0.3940,  0.3137, -0.4487,  ...,  0.8584, -0.1635, -0.0524],
        [-0.0619, -0.3909, -0.1670,  ...,  0.3105,  0.2281,  0.5938]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.3950,  0.2451, -0.3857,  ...,  0.7183, -0.0071, -0.0444],
        [ 0.2534, -0.1615, -0.0578,  ...,  0.6987, -0.2378,  0.4929],
        [ 0.1615, -0.1436, -0.3975,  ...,  0.8525,  0.0787,  0.4272],
        [ 0.4170, -0.2227, -0.0011,  ...,  0.3167,  0.0669, -0.1812]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.2593, -0.0607, -0.0479,  ...,  0.5898,  0.5913, -0.1017],
        [ 0.4387,  0.0315, -0.0902,  ...,  0.5679, -0.1088,  0.0053],
        [ 0.4253,  0.0534,  0.0588,  ...,  0.3108, -0.0602,  0.0719],
        [ 0.0074,  0.0564, -0.0776,  ...,  0.8779, -0.1659, -0.1644]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

outputs_teacher tensor([[ 0.3650, -0.2930, -0.2158,  ...,  0.4746, -0.1803, -0.1191],
        [ 0.4753,  0.2939, -0.1189,  ...,  0.6538,  0.0721, -0.1827],
        [ 0.1716,  0.0231, -0.4072,  ...,  0.7373, -0.2142,  0.1594],
        [ 0.1729, -0.0089, -0.3979,  ...,  0.8574, -0.2274,  0.1112]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Loss after 5740 Batch is nan 
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.1265,  0.0754, -0.4482,  ...,  0.7080, -0.1583,  0.2515],
        [ 0.3213,  0.2040, -0.0667,  ...,  0.3457, -0.2632,  0.1624],
        [ 0.0555,  0.0393, -0.1543,  ...,  0.7139,  0.2174, -0.0829],
        [ 0.3357,  0.2515, -0.3784,  ...,  0.6

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.1974,  0.0458, -0.0899,  ...,  0.4045, -0.2705,  0.0809],
        [ 0.2524,  0.1506, -0.1436,  ...,  0.4604,  0.1473, -0.0847],
        [ 0.2908,  0.4026, -0.2966,  ...,  0.6177,  0.2225,  0.0073],
        [ 0.2289,  0.2391, -0.2018,  ...,  0.2037, -0.1990,  0.1437]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.2534,  0.0477, -0.2454,  ...,  0.6533,  0.0233,  0.2266],
        [ 0.1675, -0.1548, -0.4146,  ...,  0.8384,  0.2615, -0.0708],
        [ 0.3823, -0.2683, -0.1644,  ...,  0.7988,  0.0715,  0.0448],
        [ 0.1829,  0.3623, -0.5156,  ...,  0.7061,  0.2927,  0.2255]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Loss after 5770 Batch is nan 
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBa

tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.1918,  0.2125, -0.5420,  ...,  0.8706, -0.0473,  0.1832],
        [ 0.0461, -0.2551, -0.0801,  ...,  0.4963,  0.1429,  0.2251],
        [ 0.1571,  0.2047, -0.2086,  ...,  0.5557,  0.2115, -0.2566],
        [ 0.0872,  0.2471, -0.5234,  ...,  0.7104, -0.0490,  0.3992]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.2976,  0.1433, -0.2032,  ...,  0.7095, -0.1638, -0.0928],
        [ 0.4370,  0.3550,  0.0414,  ...,  1.0791,  0.0432,  0.0627],
        [ 0.0280,  0.3777, -0.2742,  ...,  1.0820,  0.2070, -0.0023],
        [ 0.2820,  0.1448, -0.2510,  ...,  0.2473, -0.0989, -0.0260]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

outputs_teacher tensor([[ 0.1753,  0.1279, -0.2346,  ...,  0.8657, -0.2944,  0.0075],
        [ 0.1550,  0.0751, -0.3821,  ...,  1.0020, -0.0731,  0.0651],
        [ 0.3730, -0.0199, -0.2622,  ...,  0.4434, -0.2803,  0.4585],
        [ 0.0188,  0.2003, -0.2148,  ...,  0.8931,  0.2032,  0.2346]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.3745,  0.1490, -0.2179,  ...,  1.1221,  0.0955,  0.0754],
        [ 0.3464,  0.2012, -0.2771,  ...,  0.7310, -0.0113,  0.0062],
        [ 0.0958, -0.1085, -0.1957,  ...,  0.4370, -0.0837,  0.4695],
        [ 0.0325,  0.1945, -0.3320,  ...,  0.6118, -0.0247, -0.1820]],
     

outputs_teacher tensor([[-0.2642,  0.0964, -0.1486,  ...,  0.4817, -0.1166, -0.0134],
        [-0.0822,  0.0228, -0.2927,  ...,  1.0088, -0.1035,  0.0229],
        [ 0.2465,  0.4099, -0.5112,  ...,  1.0205,  0.1965, -0.1809],
        [ 0.3452, -0.0099, -0.2717,  ...,  0.9990, -0.0632,  0.1387]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.2368,  0.4421, -0.4014,  ...,  0.6353,  0.2690, -0.3350],
        [ 0.3916,  0.2817, -0.2391,  ...,  0.7744, -0.1742,  0.2798],
        [ 0.3142, -0.3206, -0.4080,  ...,  0.9375,  0.1058, -0.0769],
        [-0.0881,  0.1920, -0.1954,  ...,  0.9473, -0.0181,  0.1825]],
     

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.2168,  0.1891, -0.2054,  ...,  0.5908, -0.1158,  0.2593],
        [ 0.5273, -0.2040, -0.3718,  ...,  0.4319, -0.1630,  0.3894],
        [ 0.2825,  0.4329, -0.1532,  ...,  0.2239, -0.1299,  0.5781],
        [ 0.4104,  0.0030, -0.1700,  ...,  0.8027, -0.0092, -0.2917]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.4238, -0.0372, -0.1881,  ...,  0.7134, -0.1562,  0.0298],
        [ 0.0386, -0.1565, -0.3740,  ...,  0.6694, -0.0619,  0.3164],
        [ 0.2252,  0.2391, -0.3696,  ...,  0.7622,  0.2107,  0.0156],
        [ 0.3149,  0.3960, -0.2328,  ...,  0.9482,  0.1273, -0.1134]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.2346,  0.2769, -0.3718,  ...,  0.9478, -0.0442, -0.2632],
        [ 0.0670, -0.2097, -0.2595,  ...,  0.8672,  0.0051,  0.2362],
        [ 0.2230,  0.1948, -0.2849,  ...,  0.9287,  0.0906, -0.0390],
        [ 0.0201,  0.2124, -0.1967,  ...,  0.5566, -0.2700,  0.1564]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.1250,  0.0776,  0.0873,  ...,  0.4866, -0.3516, -0.0508],
        [ 0.2874, -0.0729, -0.2386,  ...,  0.3196,  0.3044,  0.4648],
        [ 0.1689, -0.3943, -0.2277,  ...,  0.3735, -0.2571, -0.0515],
        [ 0.4224, -0.2164, -0.2761,  ...,  0.3252,  0.0023,  0.3982]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

outputs_teacher tensor([[ 0.2156, -0.0916, -0.0336,  ...,  0.3101, -0.2969, -0.1361],
        [-0.0169, -0.1064, -0.2223,  ...,  0.4526, -0.2883, -0.0597],
        [-0.0725,  0.5620, -0.2959,  ...,  0.8296, -0.2096,  0.1175],
        [ 0.3301,  0.3079, -0.5029,  ...,  0.7573, -0.3579,  0.2323]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.2537,  0.2450, -0.4802,  ...,  1.0889, -0.3035,  0.2490],
        [-0.0255,  0.1675, -0.3047,  ...,  0.7144,  0.0054,  0.2771],
        [-0.0393,  0.3538, -0.0561,  ...,  0.4724,  0.0485,  0.0133],
        [ 0.1913,  0.1516, -0.5303,  ...,  0.3853, -0.1786,  0.4236]],
     

outputs_teacher tensor([[ 0.1829,  0.4795,  0.0191,  ...,  0.3904,  0.3286, -0.2607],
        [ 0.1072,  0.0558, -0.2207,  ...,  1.0088,  0.0227,  0.2205],
        [ 0.2148,  0.0763, -0.1953,  ...,  0.5786, -0.3479,  0.2037],
        [ 0.0105,  0.3252,  0.0368,  ...,  0.8262, -0.0837, -0.0549]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.0413, -0.1060, -0.2496,  ...,  0.7432, -0.0029,  0.0816],
        [ 0.3096, -0.0909, -0.1361,  ...,  0.7100,  0.2734,  0.4893],
        [ 0.3411,  0.2042, -0.4912,  ...,  0.7905, -0.2184, -0.1389],
        [ 0.3000,  0.1503, -0.4016,  ...,  0.4983, -0.0338, -0.0583]],
     

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.1642,  0.1080, -0.1786,  ...,  0.5732,  0.2229,  0.0298],
        [ 0.0402, -0.0987, -0.2922,  ...,  0.3870, -0.3354,  0.1698],
        [ 0.1581, -0.0706, -0.4719,  ...,  0.7915,  0.1422,  0.1862],
        [-0.6162,  0.5225, -0.4414,  ...,  0.4597,  0.0023,  0.1881]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.2253, -0.0900, -0.3933,  ...,  0.7637, -0.0045,  0.3430],
        [ 0.5112,  0.1731, -0.4060,  ...,  0.5093,  0.0408,  0.3599],
        [ 0.0968,  0.0150, -0.4302,  ...,  0.7397,  0.0445,  0.2250],
        [ 0.3154,  0.2004, -0.4126,  ...,  0.0532,  0.0043,  0.6812]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0

tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Loss after 5940 Batch is nan 
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.1896,  0.3767, -0.6099,  ...,  0.8027, -0.0586, -0.0105],
        [ 0.1936,  0.0490, -0.3931,  ...,  0.5815,  0.0944,  0.1831],
        [ 0.2190,  0.2473, -0.1498,  ...,  0.2771,  0.0662,  0.0271],
        [ 0.3625,  0.1373, -0.1567,  ...,  0.8042, -0.1152,  0.0612]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., n

tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[-2.8662e-01,  4.5068e-01, -7.1350e-02,  ...,  9.5459e-01,
          8.3847e-03,  8.4473e-02],
        [ 6.8298e-02, -9.3508e-04, -5.6885e-01,  ...,  7.8125e-01,
         -1.1517e-01, -1.9363e-02],
        [ 3.5620e-01,  2.7390e-02, -3.9673e-01,  ...,  4.2603e-01,
         -3.3740e-01, -2.5024e-01],
        [ 4.5874e-01,  1.9263e-01, -3.9014e-01,  ...,  5.0146e-01,
         -1.1002e-02,  4.7363e-01]], device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ...

outputs_teacher tensor([[ 0.1935, -0.0610, -0.5283,  ...,  0.4915, -0.2615,  0.0932],
        [ 0.0365,  0.3535,  0.3496,  ...,  0.8628, -0.0157, -0.2393],
        [ 0.2563,  0.3733, -0.3362,  ...,  0.4333,  0.1670,  0.1907],
        [-0.1807, -0.1791, -0.1602,  ...,  0.6353, -0.3171,  0.5225]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.2202,  0.0950, -0.4043,  ...,  0.6021, -0.2000, -0.0838],
        [ 0.1868,  0.3616, -0.3701,  ...,  0.5176, -0.0260,  0.2368],
        [ 0.1006,  0.0817, -0.1226,  ...,  0.6055, -0.0945,  0.0071],
        [ 0.1868, -0.1046, -0.5098,  ...,  0.0539,  0.2416, -0.0141]],
     

outputs_teacher tensor([[ 0.5933,  0.2269, -0.3718,  ...,  0.8013, -0.1957,  0.1426],
        [ 0.1708,  0.0189, -0.2559,  ...,  0.8667, -0.0668,  0.0067],
        [ 0.2568,  0.6528, -0.3293,  ...,  1.0596,  0.0852, -0.2269],
        [ 0.4653,  0.0818, -0.4307,  ...,  0.3811,  0.0772,  0.1979]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.2279, -0.1578, -0.0422,  ...,  0.5308, -0.3330,  0.0637],
        [ 0.5771,  0.2416, -0.2236,  ...,  0.5986, -0.0609, -0.2537],
        [ 0.1034, -0.0959, -0.6167,  ...,  0.2695, -0.0024,  0.5034],
        [ 0.1296, -0.4722, -0.2649,  ...,  0.4319, -0.1399,  0.0968]],
     

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.2228, -0.0446, -0.3091,  ...,  0.8438,  0.2251, -0.0308],
        [ 0.2094,  0.1159, -0.2769,  ...,  0.7471, -0.0734,  0.0424],
        [ 0.1429, -0.3987, -0.3105,  ...,  0.4565,  0.0394,  0.0337],
        [ 0.1059,  0.4385, -0.3293,  ...,  0.3154, -0.0553,  0.2051]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Loss after 5990 Batch is nan 
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBa

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.4231,  0.1562, -0.3660,  ...,  0.5317, -0.0584,  0.3328],
        [ 0.2119,  0.2245, -0.4927,  ...,  0.9150,  0.0121, -0.0107],
        [ 0.1188,  0.1843, -0.3110,  ...,  1.0117, -0.0176,  0.0073],
        [ 0.2255,  0.1945, -0.3206,  ...,  0.5908,  0.3103, -0.1115]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.3569,  0.2786, -0.1859,  ...,  1.0439, -0.3696,  0.0157],
        [ 0.2209,  0.3337, -0.5005,  ...,  0.6694, -0.6333, -0.3862],
        [ 0.0858,  0.2333, -0.5952,  ...,  0.5811,  0.0947,  0.1010],
        [ 0.3633, -0.1058, -0.2927,  ...,  0.5547,  0.0036, -0.2321]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.1879, -0.1599, -0.4128,  ...,  0.8652, -0.0251,  0.0775],
        [ 0.2949,  0.2102, -0.3420,  ...,  0.9067,  0.0662,  0.0200],
        [ 0.5088,  0.2477, -0.0179,  ...,  0.5796, -0.0648,  0.1818],
        [ 0.1672,  0.5161, -0.2325,  ...,  0.6211, -0.1770, -0.1877]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Loss after 6030 Batch is nan 
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBa

Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tensor([[ 0.0703,  0.0454, -0.2429,  ...,  0.7446, -0.1526, -0.0285],
        [-0.5117,  0.0696, -0.0864,  ...,  0.6260, -0.1949, -0.2064],
        [ 0.2661,  0.3574, -0.0474,  ...,  0.7759, -0.2168, -0.1567],
        [-0.1626, -0.0196, -0.5337,  ...,  1.0176, -0.1074, -0.0839]],
       device='cuda:0', dtype=torch.float16)
tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>)
Before Loss 
outputs_student tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16, grad_fn=<MmBackward0>)
outputs_teacher tens

Exception ignored in: <bound method _MultiProcessingDataLoaderIter.__del__ of <torch.utils.data.dataloader._MultiProcessingDataLoaderIter object at 0x7fec19b7bac8>>
Traceback (most recent call last):
  File "/home/ecbm4040/envTF24/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 1328, in __del__
    self._shutdown_workers()
  File "/home/ecbm4040/envTF24/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 1301, in _shutdown_workers
    w.join(timeout=_utils.MP_STATUS_CHECK_INTERVAL)
  File "/usr/lib/python3.6/multiprocessing/process.py", line 124, in join
    res = self._popen.wait(timeout)
  File "/usr/lib/python3.6/multiprocessing/popen_fork.py", line 47, in wait
    if not wait([self.sentinel], timeout):
  File "/usr/lib/python3.6/multiprocessing/connection.py", line 911, in wait
    ready = selector.select(timeout)
  File "/usr/lib/python3.6/selectors.py", line 376, in select
    fd_event_list = self._poll.poll(timeout)
KeyboardInterrupt: 


KeyboardInterrupt: 

In [None]:
for x in train_dataloader:
    print(x[0].shape,x[1].shape)
    break

In [None]:
# torch.save(
#     {
#         "distilbert": student_distil_bert.state_dict(),
#     },
#     f"distiled_distilbert_sst2_{datetime.now():%Y-%m-%d_%H-%M-%S%z}.pt",
# )