In [5]:
if torch.backends.mps.is_available():
    device = torch.device("mps")  # Apple Metal
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

print(device)

mps


In [6]:
os.environ["PYTORCH_MPS_HIGH_WATERMARK_RATIO"] = "0.0"
# torch.backends.mps.enable_flash_sdp(False)


In [4]:
from diffusers import StableDiffusionPipeline
from transformers import AutoTokenizer
import os
import torch
from torch import nn
from torch.optim import AdamW
from torch.utils.data import DataLoader
from diffusers import StableDiffusionPipeline, UNet2DConditionModel, DDPMScheduler, LMSDiscreteScheduler
from transformers import CLIPTextModel, CLIPTokenizer
import clip
from tqdm import tqdm
from scripts import eng_processor
from scripts.text_preprocessing import TextImageDataset
from scripts.train_model import BertEncoder, Generator, Discriminator
from scripts.inference import generate_image_from_text
from scripts.utils import combine_dataset, Evaluator
import pandas as pd
from collections import defaultdict
from sklearn.model_selection import train_test_split
from torchvision import transforms


  from .autonotebook import tqdm as notebook_tqdm


In [35]:
combined_df = pd.read_csv("datasets/combined_dataset.csv")
combined_df.tail(5)

Unnamed: 0,image_path,caption,source
90483,datasets/coco_dataset/train2017/000000370808.jpg,A black cat looking out the window at a black ...,MSCOCO
90484,datasets/coco_dataset/train2017/000000370808.jpg,Black cat sitting on window ledge looking outs...,MSCOCO
90485,datasets/coco_dataset/train2017/000000370808.jpg,A black cat looks out the window as a crow out...,MSCOCO
90486,datasets/coco_dataset/train2017/000000370808.jpg,A cat by a window with a small bird outside.,MSCOCO
90487,datasets/coco_dataset/train2017/000000370808.jpg,A cat watches a bird through a window.,MSCOCO


In [36]:
# combined_df= combined_df.where(combined_df["source"]=="Flickr")

# combined_df = combined_df.dropna()
combined_df = combined_df.sample(frac=1, random_state=42).head(10000)
combined_df = combined_df.reset_index(drop=True)  # Reset index to sequential integers

for i in range(len(combined_df)):
    combined_df.caption[i] = eng_processor.main(combined_df.caption[i])

You are setting values through chained assignment. Currently this works in certain cases, but when using Copy-on-Write (which will become the default behaviour in pandas 3.0) this will never work to update the original DataFrame or Series, because the intermediate object on which we are setting values will behave as a copy.
A typical example is when you are setting values in a column of a DataFrame, like:

df["col"][row_indexer] = value

Use `df.loc[row_indexer, "col"] = values` instead, to perform the assignment in a single step and ensure this keeps updating the original `df`.

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy

  combined_df.caption[i] = eng_processor.main(combined_df.caption[i])


In [37]:
image_to_captions = defaultdict(list)

for _, row in combined_df.iterrows():
    img_path = row['image_path']
    image_to_captions[img_path].append(row['caption'])

In [38]:
image_paths = list(image_to_captions.keys())
train_ids, test_ids = train_test_split(image_paths, test_size=0.2, random_state=42)


In [39]:
# image transform
transform = transforms.Compose([
    transforms.Resize((256, 256)),  # Resize the image
    transforms.ToTensor(),  # Convert PIL Image to tensor
    transforms.Lambda(lambda x: (x * 2) - 1)  # Scale from [0,1] to [-1,1]
])


train_dataset = TextImageDataset(
    image_to_captions=image_to_captions,
    image_paths=train_ids,  
    transform=transform
)

test_dataset = TextImageDataset(
    image_to_captions=image_to_captions,
    image_paths=test_ids,  
    transform=transform
)

# Create dataloaders
train_loader = DataLoader(train_dataset, batch_size=2, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=2, shuffle=False)

Dataset size: 7988 (all captions included)
Dataset size: 2011 (all captions included)


In [13]:
clip_tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
clip_text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14").to(device)



In [14]:
clip_text_encoder.eval()


CLIPTextModel(
  (text_model): CLIPTextTransformer(
    (embeddings): CLIPTextEmbeddings(
      (token_embedding): Embedding(49408, 768)
      (position_embedding): Embedding(77, 768)
    )
    (encoder): CLIPEncoder(
      (layers): ModuleList(
        (0-11): 12 x CLIPEncoderLayer(
          (self_attn): CLIPSdpaAttention(
            (k_proj): Linear(in_features=768, out_features=768, bias=True)
            (v_proj): Linear(in_features=768, out_features=768, bias=True)
            (q_proj): Linear(in_features=768, out_features=768, bias=True)
            (out_proj): Linear(in_features=768, out_features=768, bias=True)
          )
          (layer_norm1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (mlp): CLIPMLP(
            (activation_fn): QuickGELUActivation()
            (fc1): Linear(in_features=768, out_features=3072, bias=True)
            (fc2): Linear(in_features=3072, out_features=768, bias=True)
          )
          (layer_norm2): LayerNorm((768,), ep

In [16]:
# Load the pretrained pipeline
pipeline = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",  # or any other model
    torch_dtype=torch.float16,         # Use float16 only if GPU/MLCompute is supported
)

pipeline.to("mps")  # Use M1/M2 GPU via Metal backend

Cannot initialize model with low cpu memory usage because `accelerate` was not found in the environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install `accelerate` for faster and less memory-intense model loading. You can do so with: 
```
pip install accelerate
```
.
Loading pipeline components...: 100%|█████████████| 7/7 [00:07<00:00,  1.02s/it]


StableDiffusionPipeline {
  "_class_name": "StableDiffusionPipeline",
  "_diffusers_version": "0.33.1",
  "_name_or_path": "runwayml/stable-diffusion-v1-5",
  "feature_extractor": [
    "transformers",
    "CLIPImageProcessor"
  ],
  "image_encoder": [
    null,
    null
  ],
  "requires_safety_checker": true,
  "safety_checker": [
    "stable_diffusion",
    "StableDiffusionSafetyChecker"
  ],
  "scheduler": [
    "diffusers",
    "PNDMScheduler"
  ],
  "text_encoder": [
    "transformers",
    "CLIPTextModel"
  ],
  "tokenizer": [
    "transformers",
    "CLIPTokenizer"
  ],
  "unet": [
    "diffusers",
    "UNet2DConditionModel"
  ],
  "vae": [
    "diffusers",
    "AutoencoderKL"
  ]
}

In [8]:
class Projection(nn.Module):
    def __init__(self, in_dim=768, out_dim=768):
        super().__init__()
        self.proj = nn.Sequential(
            nn.Linear(in_dim, out_dim),
            nn.Tanh()
        )

    def forward(self, x):
        return self.proj(x)


In [16]:

projection = Projection().to(device)

# === 4. Stable Diffusion Decoder ===
vae = pipeline.vae.to(device)
unet = pipeline.unet.to(device)
scheduler = LMSDiscreteScheduler.from_config(pipeline.scheduler.config)



In [17]:
# Freeze everything except projection
for param in vae.parameters(): param.requires_grad = False
for param in unet.parameters(): param.requires_grad = False

optimizer = torch.optim.Adam(projection.parameters(), lr=1e-4)
loss_fn = nn.MSELoss()


In [18]:
for epoch in range(10):

    unet.train()
    clip_text_encoder.eval()
    vae.eval()
    projection.train()


    for i, batch in enumerate(tqdm(train_loader, desc=f"Training Epoch {epoch+1}")):
        images = batch["image"].to(device)
        texts = batch["caption_text"]
        
        # Match input dtype with model
        dtype = next(vae.parameters()).dtype
        images = images.to(dtype=dtype)

        with torch.no_grad():
            # Text encoding
            inputs = clip_tokenizer(texts, return_tensors="pt", padding=True, truncation=True).to(device)
            inputs = {key: val.to(device) for key, val in inputs.items()}

            text_feats = clip_text_encoder(**inputs).last_hidden_state  # (B, seq_len, 768)


            # get latent for vae
            latents = vae.encode(images).latent_dist.sample() * 0.18215


        pooled_feats = text_feats.mean(dim=1)  # (B, 768)
        projected_feats = projection(pooled_feats)

        # add noise to it
        with torch.no_grad():

            noise = torch.randn_like(latents)
            timesteps = torch.randint(0, 1000, (latents.shape[0],), device=device).long()
            noisy_latents = scheduler.add_noise(latents, noise, timesteps)

        projected_feats = projected_feats.to(dtype=noisy_latents.dtype).unsqueeze(1)  # (B, 1, D)

        noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states=projected_feats).sample

        loss = loss_fn(noise_pred, noise)

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()


        if i % 100 == 0:
            print(f"[Batch {i}/{len(train_loader)}] [Loss: {loss.item():.4f}]")

        # Clear memory
        del images, texts, inputs, text_feats, pooled_feats, projected_feats
        del latents, noise, timesteps, noisy_latents, noise_pred
        torch.mps.empty_cache()

    print(f"Epoch {epoch} finished")


# Save model after training
torch.save({
    'unet': unet.state_dict(),
    'clip_text_encoder': clip_text_encoder.state_dict(),
    'projection': projection.state_dict()
}, "final_model_checkpoint.pt")

print("Training completed ✅")

Training Epoch 1:   0%|                      | 1/4016 [00:01<1:23:43,  1.25s/it]

[Batch 0/4016] [Loss: 0.3232]


Training Epoch 1:   3%|▌                     | 101/4016 [00:40<25:30,  2.56it/s]

[Batch 100/4016] [Loss: 0.3823]


Training Epoch 1:   5%|█                     | 201/4016 [01:19<25:03,  2.54it/s]

[Batch 200/4016] [Loss: 0.3982]


Training Epoch 1:   7%|█▋                    | 301/4016 [02:01<26:38,  2.32it/s]

[Batch 300/4016] [Loss: 0.3384]


Training Epoch 1:  10%|██▏                   | 401/4016 [02:40<23:27,  2.57it/s]

[Batch 400/4016] [Loss: 0.3679]


Training Epoch 1:  12%|██▋                   | 501/4016 [03:20<24:02,  2.44it/s]

[Batch 500/4016] [Loss: 0.4478]


Training Epoch 1:  15%|███▎                  | 601/4016 [03:59<21:57,  2.59it/s]

[Batch 600/4016] [Loss: 0.3013]


Training Epoch 1:  17%|███▊                  | 701/4016 [04:38<22:00,  2.51it/s]

[Batch 700/4016] [Loss: 0.4221]


Training Epoch 1:  20%|████▍                 | 801/4016 [05:18<20:42,  2.59it/s]

[Batch 800/4016] [Loss: 0.4031]


Training Epoch 1:  22%|████▉                 | 901/4016 [05:57<19:57,  2.60it/s]

[Batch 900/4016] [Loss: 0.5649]


Training Epoch 1:  25%|█████▏               | 1001/4016 [06:36<20:12,  2.49it/s]

[Batch 1000/4016] [Loss: 0.2632]


Training Epoch 1:  27%|█████▊               | 1101/4016 [07:15<19:02,  2.55it/s]

[Batch 1100/4016] [Loss: 0.3904]


Training Epoch 1:  30%|██████▎              | 1201/4016 [07:55<17:54,  2.62it/s]

[Batch 1200/4016] [Loss: 0.3252]


Training Epoch 1:  32%|██████▊              | 1301/4016 [08:34<17:43,  2.55it/s]

[Batch 1300/4016] [Loss: 0.2568]


Training Epoch 1:  35%|███████▎             | 1401/4016 [09:13<17:29,  2.49it/s]

[Batch 1400/4016] [Loss: 0.3064]


Training Epoch 1:  37%|███████▊             | 1501/4016 [09:53<16:23,  2.56it/s]

[Batch 1500/4016] [Loss: 0.3442]


Training Epoch 1:  40%|████████▎            | 1601/4016 [10:32<16:00,  2.52it/s]

[Batch 1600/4016] [Loss: 0.3782]


Training Epoch 1:  42%|████████▉            | 1701/4016 [11:11<14:51,  2.60it/s]

[Batch 1700/4016] [Loss: 0.3552]


Training Epoch 1:  45%|█████████▍           | 1801/4016 [11:50<14:26,  2.56it/s]

[Batch 1800/4016] [Loss: 0.3550]


Training Epoch 1:  47%|█████████▉           | 1901/4016 [12:30<13:34,  2.60it/s]

[Batch 1900/4016] [Loss: 0.5303]


Training Epoch 1:  50%|██████████▍          | 2001/4016 [13:09<13:15,  2.53it/s]

[Batch 2000/4016] [Loss: 0.3474]


Training Epoch 1:  52%|██████████▉          | 2101/4016 [13:48<12:22,  2.58it/s]

[Batch 2100/4016] [Loss: 0.3489]


Training Epoch 1:  55%|███████████▌         | 2201/4016 [14:28<11:58,  2.53it/s]

[Batch 2200/4016] [Loss: 0.2708]


Training Epoch 1:  57%|████████████         | 2301/4016 [15:07<11:02,  2.59it/s]

[Batch 2300/4016] [Loss: 0.4895]


Training Epoch 1:  60%|████████████▌        | 2401/4016 [15:46<10:32,  2.55it/s]

[Batch 2400/4016] [Loss: 0.2231]


Training Epoch 1:  62%|█████████████        | 2501/4016 [16:25<09:52,  2.56it/s]

[Batch 2500/4016] [Loss: 0.4858]


Training Epoch 1:  65%|█████████████▌       | 2601/4016 [17:04<09:15,  2.55it/s]

[Batch 2600/4016] [Loss: 0.3733]


Training Epoch 1:  67%|██████████████       | 2701/4016 [17:44<08:44,  2.51it/s]

[Batch 2700/4016] [Loss: 0.1864]


Training Epoch 1:  70%|██████████████▋      | 2801/4016 [18:23<07:53,  2.56it/s]

[Batch 2800/4016] [Loss: 0.4731]


Training Epoch 1:  72%|███████████████▏     | 2901/4016 [19:02<07:17,  2.55it/s]

[Batch 2900/4016] [Loss: 0.4316]


Training Epoch 1:  75%|███████████████▋     | 3001/4016 [19:42<06:37,  2.55it/s]

[Batch 3000/4016] [Loss: 0.3027]


Training Epoch 1:  77%|████████████████▏    | 3101/4016 [20:21<05:57,  2.56it/s]

[Batch 3100/4016] [Loss: 0.3328]


Training Epoch 1:  80%|████████████████▋    | 3201/4016 [21:00<05:32,  2.45it/s]

[Batch 3200/4016] [Loss: 0.3018]


Training Epoch 1:  82%|█████████████████▎   | 3301/4016 [21:39<04:38,  2.56it/s]

[Batch 3300/4016] [Loss: 0.3267]


Training Epoch 1:  85%|█████████████████▊   | 3401/4016 [22:19<04:05,  2.51it/s]

[Batch 3400/4016] [Loss: 0.2256]


Training Epoch 1:  87%|██████████████████▎  | 3501/4016 [22:58<03:21,  2.56it/s]

[Batch 3500/4016] [Loss: 0.4500]


Training Epoch 1:  90%|██████████████████▊  | 3601/4016 [23:37<02:43,  2.54it/s]

[Batch 3600/4016] [Loss: 0.3853]


Training Epoch 1:  92%|███████████████████▎ | 3701/4016 [24:16<02:02,  2.57it/s]

[Batch 3700/4016] [Loss: 0.2593]


Training Epoch 1:  95%|███████████████████▉ | 3801/4016 [25:05<02:24,  1.49it/s]

[Batch 3800/4016] [Loss: 0.3101]


Training Epoch 1:  97%|████████████████████▍| 3901/4016 [26:13<01:16,  1.51it/s]

[Batch 3900/4016] [Loss: 0.3748]


Training Epoch 1: 100%|████████████████████▉| 4001/4016 [27:18<00:09,  1.53it/s]

[Batch 4000/4016] [Loss: 0.3196]


Training Epoch 1: 100%|█████████████████████| 4016/4016 [27:28<00:00,  2.44it/s]


Epoch 0 finished


Training Epoch 2:   0%|                        | 1/4016 [00:00<43:03,  1.55it/s]

[Batch 0/4016] [Loss: 0.4053]


Training Epoch 2:   3%|▌                     | 101/4016 [01:06<42:21,  1.54it/s]

[Batch 100/4016] [Loss: 0.2140]


Training Epoch 2:   5%|█                     | 201/4016 [02:11<42:17,  1.50it/s]

[Batch 200/4016] [Loss: 0.3354]


Training Epoch 2:   7%|█▋                    | 301/4016 [03:17<40:12,  1.54it/s]

[Batch 300/4016] [Loss: 0.1713]


Training Epoch 2:  10%|██▏                   | 401/4016 [22:08<42:07,  1.43it/s]

[Batch 400/4016] [Loss: 0.3389]


Training Epoch 2:  12%|██▋                   | 501/4016 [23:16<41:09,  1.42it/s]

[Batch 500/4016] [Loss: 0.3582]


Training Epoch 2:  15%|███▎                  | 601/4016 [24:29<44:26,  1.28it/s]

[Batch 600/4016] [Loss: 0.2600]


Training Epoch 2:  17%|███▊                  | 701/4016 [25:43<37:33,  1.47it/s]

[Batch 700/4016] [Loss: 0.3657]


Training Epoch 2:  20%|████▍                 | 801/4016 [26:51<26:47,  2.00it/s]

[Batch 800/4016] [Loss: 0.5713]


Training Epoch 2:  22%|████▉                 | 901/4016 [27:31<20:32,  2.53it/s]

[Batch 900/4016] [Loss: 0.3525]


Training Epoch 2:  25%|█████▏               | 1001/4016 [28:11<19:33,  2.57it/s]

[Batch 1000/4016] [Loss: 0.2876]


Training Epoch 2:  27%|█████▊               | 1101/4016 [28:53<21:27,  2.26it/s]

[Batch 1100/4016] [Loss: 0.1978]


Training Epoch 2:  30%|██████▎              | 1201/4016 [29:37<22:10,  2.12it/s]

[Batch 1200/4016] [Loss: 0.5640]


Training Epoch 2:  32%|██████▊              | 1301/4016 [30:19<17:22,  2.60it/s]

[Batch 1300/4016] [Loss: 0.2062]


Training Epoch 2:  35%|███████▎             | 1401/4016 [30:58<17:07,  2.55it/s]

[Batch 1400/4016] [Loss: 0.2712]


Training Epoch 2:  37%|███████▊             | 1501/4016 [31:39<17:27,  2.40it/s]

[Batch 1500/4016] [Loss: 0.3359]


Training Epoch 2:  40%|████████▎            | 1601/4016 [32:19<16:10,  2.49it/s]

[Batch 1600/4016] [Loss: 0.3535]


Training Epoch 2:  42%|████████▉            | 1701/4016 [33:00<16:07,  2.39it/s]

[Batch 1700/4016] [Loss: 0.2106]


Training Epoch 2:  45%|█████████▍           | 1801/4016 [33:40<14:16,  2.59it/s]

[Batch 1800/4016] [Loss: 0.2549]


Training Epoch 2:  47%|█████████▉           | 1901/4016 [34:22<15:41,  2.25it/s]

[Batch 1900/4016] [Loss: 0.3665]


Training Epoch 2:  50%|██████████▍          | 2001/4016 [35:02<13:14,  2.54it/s]

[Batch 2000/4016] [Loss: 0.3787]


Training Epoch 2:  52%|██████████▉          | 2101/4016 [35:41<12:27,  2.56it/s]

[Batch 2100/4016] [Loss: 0.1720]


Training Epoch 2:  55%|███████████▌         | 2201/4016 [36:24<12:05,  2.50it/s]

[Batch 2200/4016] [Loss: 0.6113]


Training Epoch 2:  57%|████████████         | 2301/4016 [37:04<10:57,  2.61it/s]

[Batch 2300/4016] [Loss: 0.3308]


Training Epoch 2:  60%|████████████▌        | 2401/4016 [37:43<10:29,  2.57it/s]

[Batch 2400/4016] [Loss: 0.4697]


Training Epoch 2:  62%|█████████████        | 2501/4016 [38:23<09:58,  2.53it/s]

[Batch 2500/4016] [Loss: 0.2725]


Training Epoch 2:  65%|█████████████▌       | 2601/4016 [39:03<09:22,  2.52it/s]

[Batch 2600/4016] [Loss: 0.2915]


Training Epoch 2:  67%|██████████████       | 2701/4016 [39:42<08:43,  2.51it/s]

[Batch 2700/4016] [Loss: 0.3047]


Training Epoch 2:  70%|██████████████▋      | 2801/4016 [40:22<08:09,  2.48it/s]

[Batch 2800/4016] [Loss: 0.3347]


Training Epoch 2:  72%|███████████████▏     | 2901/4016 [41:01<07:17,  2.55it/s]

[Batch 2900/4016] [Loss: 0.2917]


Training Epoch 2:  75%|███████████████▋     | 3001/4016 [41:41<06:35,  2.56it/s]

[Batch 3000/4016] [Loss: 0.3059]


Training Epoch 2:  77%|████████████████▏    | 3101/4016 [42:20<06:10,  2.47it/s]

[Batch 3100/4016] [Loss: 0.4976]


Training Epoch 2:  80%|████████████████▋    | 3201/4016 [43:00<05:18,  2.56it/s]

[Batch 3200/4016] [Loss: 0.2600]


Training Epoch 2:  82%|█████████████████▎   | 3301/4016 [43:39<04:37,  2.58it/s]

[Batch 3300/4016] [Loss: 0.3877]


Training Epoch 2:  85%|█████████████████▊   | 3401/4016 [44:19<04:08,  2.48it/s]

[Batch 3400/4016] [Loss: 0.4727]


Training Epoch 2:  87%|██████████████████▎  | 3501/4016 [44:59<03:19,  2.58it/s]

[Batch 3500/4016] [Loss: 0.2278]


Training Epoch 2:  90%|██████████████████▊  | 3601/4016 [45:40<02:50,  2.44it/s]

[Batch 3600/4016] [Loss: 0.2129]


Training Epoch 2:  92%|███████████████████▎ | 3701/4016 [46:20<02:02,  2.57it/s]

[Batch 3700/4016] [Loss: 0.1749]


Training Epoch 2:  95%|███████████████████▉ | 3801/4016 [46:59<01:23,  2.58it/s]

[Batch 3800/4016] [Loss: 0.1858]


Training Epoch 2:  97%|████████████████████▍| 3901/4016 [47:38<00:45,  2.50it/s]

[Batch 3900/4016] [Loss: 0.4187]


Training Epoch 2: 100%|████████████████████▉| 4001/4016 [48:18<00:06,  2.43it/s]

[Batch 4000/4016] [Loss: 0.2129]


Training Epoch 2: 100%|█████████████████████| 4016/4016 [48:24<00:00,  1.38it/s]


Epoch 1 finished


Training Epoch 3:   0%|                        | 1/4016 [00:00<27:31,  2.43it/s]

[Batch 0/4016] [Loss: 0.5103]


Training Epoch 3:   3%|▌                     | 101/4016 [00:40<25:38,  2.54it/s]

[Batch 100/4016] [Loss: 0.4460]


Training Epoch 3:   5%|█                     | 201/4016 [01:20<24:58,  2.55it/s]

[Batch 200/4016] [Loss: 0.3391]


Training Epoch 3:   7%|█▋                    | 301/4016 [01:59<24:00,  2.58it/s]

[Batch 300/4016] [Loss: 0.4492]


Training Epoch 3:  10%|██▏                   | 401/4016 [02:39<24:31,  2.46it/s]

[Batch 400/4016] [Loss: 0.5796]


Training Epoch 3:  12%|██▋                   | 501/4016 [03:19<23:36,  2.48it/s]

[Batch 500/4016] [Loss: 0.7188]


Training Epoch 3:  15%|███▎                  | 601/4016 [03:59<21:59,  2.59it/s]

[Batch 600/4016] [Loss: 0.4055]


Training Epoch 3:  17%|███▊                  | 701/4016 [04:39<21:43,  2.54it/s]

[Batch 700/4016] [Loss: 0.2029]


Training Epoch 3:  20%|████▍                 | 801/4016 [05:18<22:09,  2.42it/s]

[Batch 800/4016] [Loss: 0.4446]


Training Epoch 3:  22%|████▉                 | 901/4016 [05:59<21:14,  2.44it/s]

[Batch 900/4016] [Loss: 0.2065]


Training Epoch 3:  25%|█████▏               | 1001/4016 [06:39<19:53,  2.53it/s]

[Batch 1000/4016] [Loss: 0.4026]


Training Epoch 3:  27%|█████▊               | 1101/4016 [07:18<19:01,  2.55it/s]

[Batch 1100/4016] [Loss: 0.2344]


Training Epoch 3:  30%|██████▎              | 1201/4016 [07:58<18:03,  2.60it/s]

[Batch 1200/4016] [Loss: 0.3364]


Training Epoch 3:  32%|██████▊              | 1301/4016 [08:37<17:59,  2.51it/s]

[Batch 1300/4016] [Loss: 0.4104]


Training Epoch 3:  35%|███████▎             | 1401/4016 [09:19<19:08,  2.28it/s]

[Batch 1400/4016] [Loss: 0.4058]


Training Epoch 3:  37%|███████▊             | 1501/4016 [10:00<16:54,  2.48it/s]

[Batch 1500/4016] [Loss: 0.4766]


Training Epoch 3:  40%|████████▎            | 1601/4016 [10:40<15:53,  2.53it/s]

[Batch 1600/4016] [Loss: 0.3416]


Training Epoch 3:  42%|████████▉            | 1701/4016 [11:21<15:03,  2.56it/s]

[Batch 1700/4016] [Loss: 0.2681]


Training Epoch 3:  45%|█████████▍           | 1801/4016 [12:01<14:30,  2.54it/s]

[Batch 1800/4016] [Loss: 0.2080]


Training Epoch 3:  47%|█████████▉           | 1901/4016 [12:40<14:13,  2.48it/s]

[Batch 1900/4016] [Loss: 0.5874]


Training Epoch 3:  50%|██████████▍          | 2001/4016 [13:20<13:04,  2.57it/s]

[Batch 2000/4016] [Loss: 0.2852]


Training Epoch 3:  52%|██████████▉          | 2101/4016 [13:59<12:06,  2.63it/s]

[Batch 2100/4016] [Loss: 0.2240]


Training Epoch 3:  55%|███████████▌         | 2201/4016 [14:38<11:40,  2.59it/s]

[Batch 2200/4016] [Loss: 0.4124]


Training Epoch 3:  57%|████████████         | 2301/4016 [15:17<11:02,  2.59it/s]

[Batch 2300/4016] [Loss: 0.2163]


Training Epoch 3:  60%|████████████▌        | 2401/4016 [15:56<10:47,  2.49it/s]

[Batch 2400/4016] [Loss: 0.3806]


Training Epoch 3:  62%|█████████████        | 2501/4016 [16:35<09:39,  2.61it/s]

[Batch 2500/4016] [Loss: 0.2983]


Training Epoch 3:  65%|█████████████▌       | 2601/4016 [17:14<09:06,  2.59it/s]

[Batch 2600/4016] [Loss: 0.3535]


Training Epoch 3:  67%|██████████████       | 2701/4016 [17:53<08:34,  2.56it/s]

[Batch 2700/4016] [Loss: 0.3367]


Training Epoch 3:  70%|██████████████▋      | 2801/4016 [18:33<08:02,  2.52it/s]

[Batch 2800/4016] [Loss: 0.2690]


Training Epoch 3:  72%|███████████████▏     | 2901/4016 [19:12<07:12,  2.58it/s]

[Batch 2900/4016] [Loss: 0.2881]


Training Epoch 3:  75%|███████████████▋     | 3001/4016 [19:51<06:29,  2.61it/s]

[Batch 3000/4016] [Loss: 0.4038]


Training Epoch 3:  77%|████████████████▏    | 3101/4016 [20:30<05:51,  2.60it/s]

[Batch 3100/4016] [Loss: 0.2010]


Training Epoch 3:  80%|████████████████▋    | 3201/4016 [21:10<05:29,  2.47it/s]

[Batch 3200/4016] [Loss: 0.3169]


Training Epoch 3:  82%|█████████████████▎   | 3301/4016 [21:49<04:46,  2.49it/s]

[Batch 3300/4016] [Loss: 0.3745]


Training Epoch 3:  85%|█████████████████▊   | 3401/4016 [22:28<03:57,  2.59it/s]

[Batch 3400/4016] [Loss: 0.2959]


Training Epoch 3:  87%|██████████████████▎  | 3501/4016 [23:08<03:17,  2.61it/s]

[Batch 3500/4016] [Loss: 0.3989]


Training Epoch 3:  90%|██████████████████▊  | 3601/4016 [23:47<02:39,  2.60it/s]

[Batch 3600/4016] [Loss: 0.3364]


Training Epoch 3:  92%|███████████████████▎ | 3701/4016 [24:26<02:02,  2.57it/s]

[Batch 3700/4016] [Loss: 0.4727]


Training Epoch 3:  95%|███████████████████▉ | 3801/4016 [25:06<01:28,  2.43it/s]

[Batch 3800/4016] [Loss: 0.3669]


Training Epoch 3:  97%|████████████████████▍| 3901/4016 [25:45<00:44,  2.57it/s]

[Batch 3900/4016] [Loss: 0.2092]


Training Epoch 3: 100%|████████████████████▉| 4001/4016 [26:24<00:05,  2.56it/s]

[Batch 4000/4016] [Loss: 0.1622]


Training Epoch 3: 100%|█████████████████████| 4016/4016 [26:30<00:00,  2.53it/s]


Epoch 2 finished


Training Epoch 4:   0%|                        | 1/4016 [00:00<27:20,  2.45it/s]

[Batch 0/4016] [Loss: 0.3308]


Training Epoch 4:   3%|▌                     | 101/4016 [00:40<26:23,  2.47it/s]

[Batch 100/4016] [Loss: 0.3894]


Training Epoch 4:   5%|█                     | 201/4016 [01:19<25:35,  2.48it/s]

[Batch 200/4016] [Loss: 0.3630]


Training Epoch 4:   7%|█▋                    | 301/4016 [01:58<24:11,  2.56it/s]

[Batch 300/4016] [Loss: 0.1675]


Training Epoch 4:  10%|██▏                   | 401/4016 [02:37<22:57,  2.62it/s]

[Batch 400/4016] [Loss: 0.1577]


Training Epoch 4:  12%|██▋                   | 501/4016 [03:17<22:41,  2.58it/s]

[Batch 500/4016] [Loss: 0.3442]


Training Epoch 4:  15%|███▎                  | 601/4016 [03:56<22:29,  2.53it/s]

[Batch 600/4016] [Loss: 0.3625]


Training Epoch 4:  17%|███▊                  | 701/4016 [04:35<21:20,  2.59it/s]

[Batch 700/4016] [Loss: 0.2329]


Training Epoch 4:  20%|████▍                 | 801/4016 [05:14<20:41,  2.59it/s]

[Batch 800/4016] [Loss: 0.5103]


Training Epoch 4:  22%|████▉                 | 901/4016 [05:53<19:41,  2.64it/s]

[Batch 900/4016] [Loss: 0.2281]


Training Epoch 4:  25%|█████▏               | 1001/4016 [06:32<19:22,  2.59it/s]

[Batch 1000/4016] [Loss: 0.2834]


Training Epoch 4:  27%|█████▊               | 1101/4016 [07:11<19:19,  2.52it/s]

[Batch 1100/4016] [Loss: 0.3586]


Training Epoch 4:  30%|██████▎              | 1201/4016 [07:51<18:27,  2.54it/s]

[Batch 1200/4016] [Loss: 0.3730]


Training Epoch 4:  32%|██████▊              | 1301/4016 [08:30<17:51,  2.53it/s]

[Batch 1300/4016] [Loss: 0.3167]


Training Epoch 4:  35%|███████▎             | 1401/4016 [09:09<16:36,  2.62it/s]

[Batch 1400/4016] [Loss: 0.4255]


Training Epoch 4:  37%|███████▊             | 1501/4016 [09:48<16:16,  2.57it/s]

[Batch 1500/4016] [Loss: 0.3481]


Training Epoch 4:  40%|████████▎            | 1601/4016 [10:27<16:37,  2.42it/s]

[Batch 1600/4016] [Loss: 0.2820]


Training Epoch 4:  42%|████████▉            | 1701/4016 [11:06<14:50,  2.60it/s]

[Batch 1700/4016] [Loss: 0.4902]


Training Epoch 4:  45%|█████████▍           | 1801/4016 [11:45<14:08,  2.61it/s]

[Batch 1800/4016] [Loss: 0.4297]


Training Epoch 4:  47%|█████████▉           | 1901/4016 [12:24<13:31,  2.61it/s]

[Batch 1900/4016] [Loss: 0.3608]


Training Epoch 4:  50%|██████████▍          | 2001/4016 [13:03<13:16,  2.53it/s]

[Batch 2000/4016] [Loss: 0.3679]


Training Epoch 4:  52%|██████████▉          | 2101/4016 [13:43<13:59,  2.28it/s]

[Batch 2100/4016] [Loss: 0.1792]


Training Epoch 4:  55%|███████████▌         | 2201/4016 [14:22<11:52,  2.55it/s]

[Batch 2200/4016] [Loss: 0.3462]


Training Epoch 4:  57%|████████████         | 2301/4016 [15:00<10:51,  2.63it/s]

[Batch 2300/4016] [Loss: 0.3560]


Training Epoch 4:  60%|████████████▌        | 2401/4016 [15:39<10:21,  2.60it/s]

[Batch 2400/4016] [Loss: 0.3540]


Training Epoch 4:  62%|█████████████        | 2501/4016 [16:19<09:56,  2.54it/s]

[Batch 2500/4016] [Loss: 0.3501]


Training Epoch 4:  65%|█████████████▌       | 2601/4016 [16:58<09:36,  2.45it/s]

[Batch 2600/4016] [Loss: 0.3293]


Training Epoch 4:  67%|██████████████       | 2701/4016 [17:37<08:32,  2.57it/s]

[Batch 2700/4016] [Loss: 0.2350]


Training Epoch 4:  70%|██████████████▋      | 2801/4016 [18:16<07:47,  2.60it/s]

[Batch 2800/4016] [Loss: 0.2151]


Training Epoch 4:  72%|███████████████▏     | 2901/4016 [18:55<07:09,  2.59it/s]

[Batch 2900/4016] [Loss: 0.2228]


Training Epoch 4:  75%|███████████████▋     | 3001/4016 [19:34<06:39,  2.54it/s]

[Batch 3000/4016] [Loss: 0.2145]


Training Epoch 4:  77%|████████████████▏    | 3101/4016 [20:13<05:50,  2.61it/s]

[Batch 3100/4016] [Loss: 0.2205]


Training Epoch 4:  80%|████████████████▋    | 3201/4016 [20:52<05:20,  2.55it/s]

[Batch 3200/4016] [Loss: 0.5776]


Training Epoch 4:  82%|█████████████████▎   | 3301/4016 [21:31<04:33,  2.61it/s]

[Batch 3300/4016] [Loss: 0.2673]


Training Epoch 4:  85%|█████████████████▊   | 3401/4016 [22:10<03:59,  2.56it/s]

[Batch 3400/4016] [Loss: 0.3660]


Training Epoch 4:  87%|██████████████████▎  | 3501/4016 [22:50<03:26,  2.50it/s]

[Batch 3500/4016] [Loss: 0.3604]


Training Epoch 4:  90%|██████████████████▊  | 3601/4016 [23:30<02:57,  2.34it/s]

[Batch 3600/4016] [Loss: 0.2998]


Training Epoch 4:  92%|███████████████████▎ | 3701/4016 [24:14<02:07,  2.47it/s]

[Batch 3700/4016] [Loss: 0.2394]


Training Epoch 4:  95%|███████████████████▉ | 3801/4016 [24:55<01:29,  2.41it/s]

[Batch 3800/4016] [Loss: 0.3042]


Training Epoch 4:  97%|████████████████████▍| 3901/4016 [25:36<00:49,  2.32it/s]

[Batch 3900/4016] [Loss: 0.3623]


Training Epoch 4: 100%|████████████████████▉| 4001/4016 [26:17<00:06,  2.47it/s]

[Batch 4000/4016] [Loss: 0.1656]


Training Epoch 4: 100%|█████████████████████| 4016/4016 [26:23<00:00,  2.54it/s]


Epoch 3 finished


Training Epoch 5:   0%|                        | 1/4016 [00:00<26:17,  2.54it/s]

[Batch 0/4016] [Loss: 0.3892]


Training Epoch 5:   3%|▌                     | 101/4016 [00:42<26:50,  2.43it/s]

[Batch 100/4016] [Loss: 0.4651]


Training Epoch 5:   5%|█                     | 201/4016 [01:23<25:47,  2.47it/s]

[Batch 200/4016] [Loss: 0.1376]


Training Epoch 5:   7%|█▋                    | 301/4016 [02:05<29:19,  2.11it/s]

[Batch 300/4016] [Loss: 0.4133]


Training Epoch 5:  10%|██▏                   | 401/4016 [02:47<24:07,  2.50it/s]

[Batch 400/4016] [Loss: 0.3289]


Training Epoch 5:  12%|██▋                   | 501/4016 [03:28<22:51,  2.56it/s]

[Batch 500/4016] [Loss: 0.2179]


Training Epoch 5:  15%|███▎                  | 601/4016 [04:08<23:55,  2.38it/s]

[Batch 600/4016] [Loss: 0.3167]


Training Epoch 5:  17%|███▊                  | 701/4016 [04:48<21:47,  2.54it/s]

[Batch 700/4016] [Loss: 0.7231]


Training Epoch 5:  20%|████▍                 | 801/4016 [05:29<23:35,  2.27it/s]

[Batch 800/4016] [Loss: 0.4485]


Training Epoch 5:  22%|████▉                 | 901/4016 [06:10<20:26,  2.54it/s]

[Batch 900/4016] [Loss: 0.1642]


Training Epoch 5:  25%|█████▏               | 1001/4016 [06:49<19:38,  2.56it/s]

[Batch 1000/4016] [Loss: 0.2771]


Training Epoch 5:  27%|█████▊               | 1101/4016 [07:29<19:26,  2.50it/s]

[Batch 1100/4016] [Loss: 0.4080]


Training Epoch 5:  30%|██████▎              | 1201/4016 [08:10<18:40,  2.51it/s]

[Batch 1200/4016] [Loss: 0.5659]


Training Epoch 5:  32%|██████▊              | 1301/4016 [08:50<18:31,  2.44it/s]

[Batch 1300/4016] [Loss: 0.1896]


Training Epoch 5:  35%|███████▎             | 1401/4016 [09:33<19:47,  2.20it/s]

[Batch 1400/4016] [Loss: 0.2783]


Training Epoch 5:  37%|███████▊             | 1501/4016 [10:15<17:52,  2.34it/s]

[Batch 1500/4016] [Loss: 0.2954]


Training Epoch 5:  40%|████████▎            | 1601/4016 [10:57<16:03,  2.51it/s]

[Batch 1600/4016] [Loss: 0.2954]


Training Epoch 5:  42%|████████▉            | 1701/4016 [11:38<16:00,  2.41it/s]

[Batch 1700/4016] [Loss: 0.4194]


Training Epoch 5:  45%|█████████▍           | 1801/4016 [12:20<17:15,  2.14it/s]

[Batch 1800/4016] [Loss: 0.4165]


Training Epoch 5:  47%|█████████▉           | 1901/4016 [13:02<14:57,  2.36it/s]

[Batch 1900/4016] [Loss: 0.3389]


Training Epoch 5:  50%|██████████▍          | 2001/4016 [13:44<13:31,  2.48it/s]

[Batch 2000/4016] [Loss: 0.3423]


Training Epoch 5:  52%|██████████▉          | 2101/4016 [14:26<12:40,  2.52it/s]

[Batch 2100/4016] [Loss: 0.4871]


Training Epoch 5:  55%|███████████▌         | 2201/4016 [15:08<13:44,  2.20it/s]

[Batch 2200/4016] [Loss: 0.3594]


Training Epoch 5:  57%|████████████         | 2301/4016 [15:48<11:17,  2.53it/s]

[Batch 2300/4016] [Loss: 0.3777]


Training Epoch 5:  60%|████████████▌        | 2401/4016 [16:30<10:43,  2.51it/s]

[Batch 2400/4016] [Loss: 0.2443]


Training Epoch 5:  62%|█████████████        | 2501/4016 [17:13<10:00,  2.52it/s]

[Batch 2500/4016] [Loss: 0.4192]


Training Epoch 5:  65%|█████████████▌       | 2601/4016 [17:53<09:29,  2.48it/s]

[Batch 2600/4016] [Loss: 0.3521]


Training Epoch 5:  67%|██████████████       | 2701/4016 [18:35<10:27,  2.10it/s]

[Batch 2700/4016] [Loss: 0.3274]


Training Epoch 5:  70%|██████████████▋      | 2801/4016 [19:17<08:00,  2.53it/s]

[Batch 2800/4016] [Loss: 0.4487]


Training Epoch 5:  72%|███████████████▏     | 2901/4016 [19:58<07:25,  2.50it/s]

[Batch 2900/4016] [Loss: 0.3818]


Training Epoch 5:  75%|███████████████▋     | 3001/4016 [20:39<06:47,  2.49it/s]

[Batch 3000/4016] [Loss: 0.2913]


Training Epoch 5:  77%|████████████████▏    | 3101/4016 [21:20<06:14,  2.44it/s]

[Batch 3100/4016] [Loss: 0.3018]


Training Epoch 5:  80%|████████████████▋    | 3201/4016 [22:01<06:03,  2.24it/s]

[Batch 3200/4016] [Loss: 0.1577]


Training Epoch 5:  82%|█████████████████▎   | 3301/4016 [22:42<04:32,  2.62it/s]

[Batch 3300/4016] [Loss: 0.3340]


Training Epoch 5:  85%|█████████████████▊   | 3401/4016 [23:22<04:01,  2.55it/s]

[Batch 3400/4016] [Loss: 0.3843]


Training Epoch 5:  87%|██████████████████▎  | 3501/4016 [24:01<03:17,  2.61it/s]

[Batch 3500/4016] [Loss: 0.5708]


Training Epoch 5:  90%|██████████████████▊  | 3601/4016 [24:41<02:46,  2.49it/s]

[Batch 3600/4016] [Loss: 0.4072]


Training Epoch 5:  92%|███████████████████▎ | 3701/4016 [25:21<02:12,  2.37it/s]

[Batch 3700/4016] [Loss: 0.2664]


Training Epoch 5:  95%|███████████████████▉ | 3801/4016 [26:00<01:22,  2.62it/s]

[Batch 3800/4016] [Loss: 0.3018]


Training Epoch 5:  97%|████████████████████▍| 3901/4016 [26:40<00:44,  2.57it/s]

[Batch 3900/4016] [Loss: 0.2732]


Training Epoch 5: 100%|████████████████████▉| 4001/4016 [27:19<00:05,  2.60it/s]

[Batch 4000/4016] [Loss: 0.4614]


Training Epoch 5: 100%|█████████████████████| 4016/4016 [27:25<00:00,  2.44it/s]


Epoch 4 finished


Training Epoch 6:   0%|                        | 1/4016 [00:00<25:50,  2.59it/s]

[Batch 0/4016] [Loss: 0.3694]


Training Epoch 6:   3%|▌                     | 101/4016 [00:40<26:19,  2.48it/s]

[Batch 100/4016] [Loss: 0.3958]


Training Epoch 6:   5%|█                     | 201/4016 [01:19<24:16,  2.62it/s]

[Batch 200/4016] [Loss: 0.2935]


Training Epoch 6:   7%|█▋                    | 301/4016 [01:59<24:02,  2.58it/s]

[Batch 300/4016] [Loss: 0.3679]


Training Epoch 6:  10%|██▏                   | 401/4016 [02:38<24:26,  2.47it/s]

[Batch 400/4016] [Loss: 0.4497]


Training Epoch 6:  12%|██▋                   | 501/4016 [03:19<26:28,  2.21it/s]

[Batch 500/4016] [Loss: 0.3052]


Training Epoch 6:  15%|███▎                  | 601/4016 [04:00<22:43,  2.50it/s]

[Batch 600/4016] [Loss: 0.3369]


Training Epoch 6:  17%|███▊                  | 701/4016 [04:42<23:11,  2.38it/s]

[Batch 700/4016] [Loss: 0.4966]


Training Epoch 6:  20%|████▍                 | 801/4016 [05:23<22:07,  2.42it/s]

[Batch 800/4016] [Loss: 0.4473]


Training Epoch 6:  22%|████▉                 | 901/4016 [06:04<20:20,  2.55it/s]

[Batch 900/4016] [Loss: 0.3298]


Training Epoch 6:  25%|█████▏               | 1001/4016 [06:45<22:03,  2.28it/s]

[Batch 1000/4016] [Loss: 0.3276]


Training Epoch 6:  27%|█████▊               | 1101/4016 [07:26<19:08,  2.54it/s]

[Batch 1100/4016] [Loss: 0.4204]


Training Epoch 6:  30%|██████▎              | 1201/4016 [08:06<18:46,  2.50it/s]

[Batch 1200/4016] [Loss: 0.2778]


Training Epoch 6:  32%|██████▊              | 1301/4016 [08:48<18:12,  2.49it/s]

[Batch 1300/4016] [Loss: 0.2040]


Training Epoch 6:  35%|███████▎             | 1401/4016 [09:29<17:30,  2.49it/s]

[Batch 1400/4016] [Loss: 0.2917]


Training Epoch 6:  37%|███████▊             | 1501/4016 [10:10<17:33,  2.39it/s]

[Batch 1500/4016] [Loss: 0.1641]


Training Epoch 6:  40%|████████▎            | 1601/4016 [10:50<15:41,  2.56it/s]

[Batch 1600/4016] [Loss: 0.2495]


Training Epoch 6:  42%|████████▉            | 1701/4016 [11:31<15:14,  2.53it/s]

[Batch 1700/4016] [Loss: 0.2532]


Training Epoch 6:  45%|█████████▍           | 1801/4016 [12:12<14:44,  2.50it/s]

[Batch 1800/4016] [Loss: 0.2399]


Training Epoch 6:  47%|█████████▉           | 1901/4016 [12:52<14:23,  2.45it/s]

[Batch 1900/4016] [Loss: 0.2245]


Training Epoch 6:  50%|██████████▍          | 2001/4016 [13:33<13:40,  2.46it/s]

[Batch 2000/4016] [Loss: 0.2983]


Training Epoch 6:  52%|██████████▉          | 2101/4016 [14:13<12:22,  2.58it/s]

[Batch 2100/4016] [Loss: 0.3035]


Training Epoch 6:  55%|███████████▌         | 2201/4016 [14:53<11:52,  2.55it/s]

[Batch 2200/4016] [Loss: 0.3640]


Training Epoch 6:  57%|████████████         | 2301/4016 [15:33<11:09,  2.56it/s]

[Batch 2300/4016] [Loss: 0.3501]


Training Epoch 6:  60%|████████████▌        | 2401/4016 [16:14<11:17,  2.38it/s]

[Batch 2400/4016] [Loss: 0.4468]


Training Epoch 6:  62%|█████████████        | 2501/4016 [16:55<10:17,  2.45it/s]

[Batch 2500/4016] [Loss: 0.6060]


Training Epoch 6:  65%|█████████████▌       | 2601/4016 [17:34<09:01,  2.61it/s]

[Batch 2600/4016] [Loss: 0.2969]


Training Epoch 6:  67%|██████████████       | 2701/4016 [18:14<08:33,  2.56it/s]

[Batch 2700/4016] [Loss: 0.3750]


Training Epoch 6:  70%|██████████████▋      | 2801/4016 [18:54<07:57,  2.55it/s]

[Batch 2800/4016] [Loss: 0.4514]


Training Epoch 6:  72%|███████████████▏     | 2901/4016 [19:34<08:11,  2.27it/s]

[Batch 2900/4016] [Loss: 0.2881]


Training Epoch 6:  75%|███████████████▋     | 3001/4016 [20:14<06:31,  2.59it/s]

[Batch 3000/4016] [Loss: 0.2460]


Training Epoch 6:  77%|████████████████▏    | 3101/4016 [20:53<05:51,  2.60it/s]

[Batch 3100/4016] [Loss: 0.3398]


Training Epoch 6:  80%|████████████████▋    | 3201/4016 [21:33<05:24,  2.51it/s]

[Batch 3200/4016] [Loss: 0.3145]


Training Epoch 6:  82%|█████████████████▎   | 3301/4016 [22:13<04:43,  2.52it/s]

[Batch 3300/4016] [Loss: 0.4827]


Training Epoch 6:  85%|█████████████████▊   | 3401/4016 [22:53<04:38,  2.21it/s]

[Batch 3400/4016] [Loss: 0.4731]


Training Epoch 6:  87%|██████████████████▎  | 3501/4016 [23:33<03:16,  2.62it/s]

[Batch 3500/4016] [Loss: 0.3538]


Training Epoch 6:  90%|██████████████████▊  | 3601/4016 [24:13<02:46,  2.50it/s]

[Batch 3600/4016] [Loss: 0.2488]


Training Epoch 6:  92%|███████████████████▎ | 3701/4016 [24:54<02:07,  2.46it/s]

[Batch 3700/4016] [Loss: 0.4438]


Training Epoch 6:  95%|███████████████████▉ | 3801/4016 [25:36<01:31,  2.36it/s]

[Batch 3800/4016] [Loss: 0.3389]


Training Epoch 6:  97%|████████████████████▍| 3901/4016 [26:17<00:49,  2.35it/s]

[Batch 3900/4016] [Loss: 0.2593]


Training Epoch 6: 100%|████████████████████▉| 4001/4016 [26:59<00:06,  2.26it/s]

[Batch 4000/4016] [Loss: 0.2463]


Training Epoch 6: 100%|█████████████████████| 4016/4016 [27:05<00:00,  2.47it/s]


Epoch 5 finished


Training Epoch 7:   0%|                        | 1/4016 [00:00<27:21,  2.45it/s]

[Batch 0/4016] [Loss: 0.4814]


Training Epoch 7:   3%|▌                     | 101/4016 [00:41<26:56,  2.42it/s]

[Batch 100/4016] [Loss: 0.2964]


Training Epoch 7:   5%|█                     | 201/4016 [01:23<27:13,  2.34it/s]

[Batch 200/4016] [Loss: 0.1714]


Training Epoch 7:   7%|█▋                    | 301/4016 [02:04<25:38,  2.42it/s]

[Batch 300/4016] [Loss: 0.2222]


Training Epoch 7:  10%|██▏                   | 401/4016 [02:46<23:45,  2.54it/s]

[Batch 400/4016] [Loss: 0.3267]


Training Epoch 7:  12%|██▋                   | 501/4016 [03:27<23:50,  2.46it/s]

[Batch 500/4016] [Loss: 0.4722]


Training Epoch 7:  15%|███▎                  | 601/4016 [04:08<22:59,  2.48it/s]

[Batch 600/4016] [Loss: 0.2783]


Training Epoch 7:  17%|███▊                  | 701/4016 [04:50<24:16,  2.28it/s]

[Batch 700/4016] [Loss: 0.4314]


Training Epoch 7:  20%|████▍                 | 801/4016 [05:32<22:14,  2.41it/s]

[Batch 800/4016] [Loss: 0.2825]


Training Epoch 7:  22%|████▉                 | 901/4016 [06:13<20:32,  2.53it/s]

[Batch 900/4016] [Loss: 0.2876]


Training Epoch 7:  25%|█████▏               | 1001/4016 [06:55<20:28,  2.45it/s]

[Batch 1000/4016] [Loss: 0.3289]


Training Epoch 7:  27%|█████▊               | 1101/4016 [07:36<19:56,  2.44it/s]

[Batch 1100/4016] [Loss: 0.3015]


Training Epoch 7:  30%|██████▎              | 1201/4016 [08:17<21:42,  2.16it/s]

[Batch 1200/4016] [Loss: 0.3059]


Training Epoch 7:  32%|██████▊              | 1301/4016 [08:59<18:00,  2.51it/s]

[Batch 1300/4016] [Loss: 0.2308]


Training Epoch 7:  35%|███████▎             | 1401/4016 [09:40<17:31,  2.49it/s]

[Batch 1400/4016] [Loss: 0.3354]


Training Epoch 7:  37%|███████▊             | 1501/4016 [10:21<17:16,  2.43it/s]

[Batch 1500/4016] [Loss: 0.4075]


Training Epoch 7:  40%|████████▎            | 1601/4016 [11:03<16:31,  2.44it/s]

[Batch 1600/4016] [Loss: 0.2207]


Training Epoch 7:  42%|████████▉            | 1701/4016 [11:44<17:12,  2.24it/s]

[Batch 1700/4016] [Loss: 0.2632]


Training Epoch 7:  45%|█████████▍           | 1801/4016 [12:26<14:37,  2.52it/s]

[Batch 1800/4016] [Loss: 0.3496]


Training Epoch 7:  47%|█████████▉           | 1901/4016 [13:07<14:10,  2.49it/s]

[Batch 1900/4016] [Loss: 0.3372]


Training Epoch 7:  50%|██████████▍          | 2001/4016 [13:48<13:39,  2.46it/s]

[Batch 2000/4016] [Loss: 0.3997]


Training Epoch 7:  52%|██████████▉          | 2101/4016 [14:30<13:30,  2.36it/s]

[Batch 2100/4016] [Loss: 0.5308]


Training Epoch 7:  55%|███████████▌         | 2201/4016 [15:11<12:40,  2.39it/s]

[Batch 2200/4016] [Loss: 0.4392]


Training Epoch 7:  57%|████████████         | 2301/4016 [15:53<11:18,  2.53it/s]

[Batch 2300/4016] [Loss: 0.3562]


Training Epoch 7:  60%|████████████▌        | 2401/4016 [16:34<10:49,  2.49it/s]

[Batch 2400/4016] [Loss: 0.2225]


Training Epoch 7:  62%|█████████████        | 2501/4016 [17:16<10:18,  2.45it/s]

[Batch 2500/4016] [Loss: 0.2800]


Training Epoch 7:  65%|█████████████▌       | 2601/4016 [17:57<10:10,  2.32it/s]

[Batch 2600/4016] [Loss: 0.4263]


Training Epoch 7:  67%|██████████████       | 2701/4016 [18:39<09:01,  2.43it/s]

[Batch 2700/4016] [Loss: 0.3657]


Training Epoch 7:  70%|██████████████▋      | 2801/4016 [19:20<07:59,  2.53it/s]

[Batch 2800/4016] [Loss: 0.4333]


Training Epoch 7:  72%|███████████████▏     | 2901/4016 [20:01<07:34,  2.45it/s]

[Batch 2900/4016] [Loss: 0.2810]


Training Epoch 7:  75%|███████████████▋     | 3001/4016 [20:43<06:51,  2.47it/s]

[Batch 3000/4016] [Loss: 0.2991]


Training Epoch 7:  77%|████████████████▏    | 3101/4016 [21:24<06:49,  2.24it/s]

[Batch 3100/4016] [Loss: 0.5059]


Training Epoch 7:  80%|████████████████▋    | 3201/4016 [22:06<05:25,  2.50it/s]

[Batch 3200/4016] [Loss: 0.4873]


Training Epoch 7:  82%|█████████████████▎   | 3301/4016 [22:47<04:43,  2.52it/s]

[Batch 3300/4016] [Loss: 0.2018]


Training Epoch 7:  85%|█████████████████▊   | 3401/4016 [23:28<04:10,  2.46it/s]

[Batch 3400/4016] [Loss: 0.4167]


Training Epoch 7:  87%|██████████████████▎  | 3501/4016 [24:10<03:30,  2.44it/s]

[Batch 3500/4016] [Loss: 0.2385]


Training Epoch 7:  90%|██████████████████▊  | 3601/4016 [24:51<03:21,  2.06it/s]

[Batch 3600/4016] [Loss: 0.3289]


Training Epoch 7:  92%|███████████████████▎ | 3701/4016 [25:33<02:04,  2.54it/s]

[Batch 3700/4016] [Loss: 0.3545]


Training Epoch 7:  95%|███████████████████▉ | 3801/4016 [26:13<01:25,  2.51it/s]

[Batch 3800/4016] [Loss: 0.3347]


Training Epoch 7:  97%|████████████████████▍| 3901/4016 [26:55<00:47,  2.41it/s]

[Batch 3900/4016] [Loss: 0.4832]


Training Epoch 7: 100%|████████████████████▉| 4001/4016 [27:36<00:06,  2.40it/s]

[Batch 4000/4016] [Loss: 0.2224]


Training Epoch 7: 100%|█████████████████████| 4016/4016 [27:42<00:00,  2.42it/s]


Epoch 6 finished


Training Epoch 8:   0%|                        | 1/4016 [00:00<27:28,  2.44it/s]

[Batch 0/4016] [Loss: 0.2267]


Training Epoch 8:   3%|▌                     | 101/4016 [00:42<25:44,  2.53it/s]

[Batch 100/4016] [Loss: 0.2395]


Training Epoch 8:   5%|█                     | 201/4016 [01:23<25:35,  2.48it/s]

[Batch 200/4016] [Loss: 0.4478]


Training Epoch 8:   7%|█▋                    | 301/4016 [02:05<25:48,  2.40it/s]

[Batch 300/4016] [Loss: 0.5142]


Training Epoch 8:  10%|██▏                   | 401/4016 [02:46<25:36,  2.35it/s]

[Batch 400/4016] [Loss: 0.2159]


Training Epoch 8:  12%|██▋                   | 501/4016 [03:28<24:23,  2.40it/s]

[Batch 500/4016] [Loss: 0.3057]


Training Epoch 8:  15%|███▎                  | 601/4016 [04:10<22:40,  2.51it/s]

[Batch 600/4016] [Loss: 0.4119]


Training Epoch 8:  17%|███▊                  | 701/4016 [04:51<22:28,  2.46it/s]

[Batch 700/4016] [Loss: 0.5234]


Training Epoch 8:  20%|████▍                 | 801/4016 [05:33<22:04,  2.43it/s]

[Batch 800/4016] [Loss: 0.2622]


Training Epoch 8:  22%|████▉                 | 901/4016 [06:15<23:57,  2.17it/s]

[Batch 900/4016] [Loss: 0.2391]


Training Epoch 8:  25%|█████▏               | 1001/4016 [06:57<21:00,  2.39it/s]

[Batch 1000/4016] [Loss: 0.2896]


Training Epoch 8:  27%|█████▊               | 1101/4016 [07:38<19:14,  2.52it/s]

[Batch 1100/4016] [Loss: 0.3325]


Training Epoch 8:  30%|██████▎              | 1201/4016 [08:19<19:14,  2.44it/s]

[Batch 1200/4016] [Loss: 0.4292]


Training Epoch 8:  32%|██████▊              | 1301/4016 [09:01<18:19,  2.47it/s]

[Batch 1300/4016] [Loss: 0.5068]


Training Epoch 8:  35%|███████▎             | 1401/4016 [09:41<19:10,  2.27it/s]

[Batch 1400/4016] [Loss: 0.2407]


Training Epoch 8:  37%|███████▊             | 1501/4016 [10:21<16:18,  2.57it/s]

[Batch 1500/4016] [Loss: 0.3494]


Training Epoch 8:  40%|████████▎            | 1601/4016 [11:01<15:25,  2.61it/s]

[Batch 1600/4016] [Loss: 0.2603]


Training Epoch 8:  42%|████████▉            | 1701/4016 [11:41<15:12,  2.54it/s]

[Batch 1700/4016] [Loss: 0.3240]


Training Epoch 8:  45%|█████████▍           | 1801/4016 [12:21<15:09,  2.44it/s]

[Batch 1800/4016] [Loss: 0.4019]


Training Epoch 8:  47%|█████████▉           | 1901/4016 [13:01<16:15,  2.17it/s]

[Batch 1900/4016] [Loss: 0.3101]


Training Epoch 8:  50%|██████████▍          | 2001/4016 [13:41<12:50,  2.62it/s]

[Batch 2000/4016] [Loss: 0.2524]


Training Epoch 8:  52%|██████████▉          | 2101/4016 [14:21<12:24,  2.57it/s]

[Batch 2100/4016] [Loss: 0.5342]


Training Epoch 8:  55%|███████████▌         | 2201/4016 [15:01<11:47,  2.56it/s]

[Batch 2200/4016] [Loss: 0.3875]


Training Epoch 8:  57%|████████████         | 2301/4016 [15:41<11:40,  2.45it/s]

[Batch 2300/4016] [Loss: 0.3279]


Training Epoch 8:  60%|████████████▌        | 2401/4016 [16:22<11:23,  2.36it/s]

[Batch 2400/4016] [Loss: 0.1744]


Training Epoch 8:  62%|█████████████        | 2501/4016 [17:01<09:41,  2.61it/s]

[Batch 2500/4016] [Loss: 0.5654]


Training Epoch 8:  65%|█████████████▌       | 2601/4016 [17:41<09:12,  2.56it/s]

[Batch 2600/4016] [Loss: 0.2311]


Training Epoch 8:  67%|██████████████       | 2701/4016 [18:22<08:36,  2.55it/s]

[Batch 2700/4016] [Loss: 0.4072]


Training Epoch 8:  70%|██████████████▋      | 2801/4016 [19:02<08:23,  2.41it/s]

[Batch 2800/4016] [Loss: 0.3093]


Training Epoch 8:  72%|███████████████▏     | 2901/4016 [19:42<07:06,  2.61it/s]

[Batch 2900/4016] [Loss: 0.7188]


Training Epoch 8:  75%|███████████████▋     | 3001/4016 [20:22<06:31,  2.59it/s]

[Batch 3000/4016] [Loss: 0.4849]


Training Epoch 8:  77%|████████████████▏    | 3101/4016 [21:01<05:52,  2.60it/s]

[Batch 3100/4016] [Loss: 0.6333]


Training Epoch 8:  80%|████████████████▋    | 3201/4016 [21:40<05:14,  2.59it/s]

[Batch 3200/4016] [Loss: 0.2993]


Training Epoch 8:  82%|█████████████████▎   | 3301/4016 [22:20<05:07,  2.32it/s]

[Batch 3300/4016] [Loss: 0.2644]


Training Epoch 8:  85%|█████████████████▊   | 3401/4016 [23:00<03:54,  2.62it/s]

[Batch 3400/4016] [Loss: 0.2134]


Training Epoch 8:  87%|██████████████████▎  | 3501/4016 [23:39<03:16,  2.62it/s]

[Batch 3500/4016] [Loss: 0.2000]


Training Epoch 8:  90%|██████████████████▊  | 3601/4016 [24:19<02:42,  2.56it/s]

[Batch 3600/4016] [Loss: 0.5205]


Training Epoch 8:  92%|███████████████████▎ | 3701/4016 [24:59<02:03,  2.54it/s]

[Batch 3700/4016] [Loss: 0.4727]


Training Epoch 8:  95%|███████████████████▉ | 3801/4016 [25:39<01:39,  2.17it/s]

[Batch 3800/4016] [Loss: 0.4438]


Training Epoch 8:  97%|████████████████████▍| 3901/4016 [26:18<00:43,  2.63it/s]

[Batch 3900/4016] [Loss: 0.2747]


Training Epoch 8: 100%|████████████████████▉| 4001/4016 [26:58<00:05,  2.58it/s]

[Batch 4000/4016] [Loss: 0.4683]


Training Epoch 8: 100%|█████████████████████| 4016/4016 [27:04<00:00,  2.47it/s]


Epoch 7 finished


Training Epoch 9:   0%|                        | 1/4016 [00:00<26:15,  2.55it/s]

[Batch 0/4016] [Loss: 0.5234]


Training Epoch 9:   3%|▌                     | 101/4016 [00:40<26:04,  2.50it/s]

[Batch 100/4016] [Loss: 0.4543]


Training Epoch 9:   5%|█                     | 201/4016 [01:19<25:55,  2.45it/s]

[Batch 200/4016] [Loss: 0.4556]


Training Epoch 9:   7%|█▋                    | 301/4016 [01:59<23:21,  2.65it/s]

[Batch 300/4016] [Loss: 0.5278]


Training Epoch 9:  10%|██▏                   | 401/4016 [02:39<23:05,  2.61it/s]

[Batch 400/4016] [Loss: 0.5542]


Training Epoch 9:  12%|██▋                   | 501/4016 [03:19<23:18,  2.51it/s]

[Batch 500/4016] [Loss: 0.4482]


Training Epoch 9:  15%|███▎                  | 601/4016 [03:59<23:24,  2.43it/s]

[Batch 600/4016] [Loss: 0.3071]


Training Epoch 9:  17%|███▊                  | 701/4016 [04:39<21:58,  2.51it/s]

[Batch 700/4016] [Loss: 0.3430]


Training Epoch 9:  20%|████▍                 | 801/4016 [05:19<20:44,  2.58it/s]

[Batch 800/4016] [Loss: 0.3296]


Training Epoch 9:  22%|████▉                 | 901/4016 [05:59<20:04,  2.59it/s]

[Batch 900/4016] [Loss: 0.4050]


Training Epoch 9:  25%|█████▏               | 1001/4016 [06:40<19:41,  2.55it/s]

[Batch 1000/4016] [Loss: 0.4399]


Training Epoch 9:  27%|█████▊               | 1101/4016 [07:20<20:40,  2.35it/s]

[Batch 1100/4016] [Loss: 0.3188]


Training Epoch 9:  30%|██████▎              | 1201/4016 [08:00<18:42,  2.51it/s]

[Batch 1200/4016] [Loss: 0.4072]


Training Epoch 9:  32%|██████▊              | 1301/4016 [08:40<17:07,  2.64it/s]

[Batch 1300/4016] [Loss: 0.3201]


Training Epoch 9:  35%|███████▎             | 1401/4016 [09:19<16:51,  2.58it/s]

[Batch 1400/4016] [Loss: 0.4463]


Training Epoch 9:  37%|███████▊             | 1501/4016 [09:59<16:25,  2.55it/s]

[Batch 1500/4016] [Loss: 0.3479]


Training Epoch 9:  40%|████████▎            | 1601/4016 [10:39<17:47,  2.26it/s]

[Batch 1600/4016] [Loss: 0.3142]


Training Epoch 9:  42%|████████▉            | 1701/4016 [11:19<15:03,  2.56it/s]

[Batch 1700/4016] [Loss: 0.4390]


Training Epoch 9:  45%|█████████▍           | 1801/4016 [11:58<14:03,  2.63it/s]

[Batch 1800/4016] [Loss: 0.3816]


Training Epoch 9:  47%|█████████▉           | 1901/4016 [12:38<13:45,  2.56it/s]

[Batch 1900/4016] [Loss: 0.3489]


Training Epoch 9:  50%|██████████▍          | 2001/4016 [13:18<13:12,  2.54it/s]

[Batch 2000/4016] [Loss: 0.3306]


Training Epoch 9:  52%|██████████▉          | 2101/4016 [13:58<14:41,  2.17it/s]

[Batch 2100/4016] [Loss: 0.4158]


Training Epoch 9:  55%|███████████▌         | 2201/4016 [14:38<11:38,  2.60it/s]

[Batch 2200/4016] [Loss: 0.5103]


Training Epoch 9:  57%|████████████         | 2301/4016 [15:17<10:59,  2.60it/s]

[Batch 2300/4016] [Loss: 0.3132]


Training Epoch 9:  60%|████████████▌        | 2401/4016 [15:57<10:29,  2.57it/s]

[Batch 2400/4016] [Loss: 0.3882]


Training Epoch 9:  62%|█████████████        | 2501/4016 [16:37<10:05,  2.50it/s]

[Batch 2500/4016] [Loss: 0.4675]


Training Epoch 9:  65%|█████████████▌       | 2601/4016 [17:17<10:38,  2.22it/s]

[Batch 2600/4016] [Loss: 0.5488]


Training Epoch 9:  67%|██████████████       | 2701/4016 [17:57<08:20,  2.63it/s]

[Batch 2700/4016] [Loss: 0.4949]


Training Epoch 9:  70%|██████████████▋      | 2801/4016 [18:36<07:45,  2.61it/s]

[Batch 2800/4016] [Loss: 0.2883]


Training Epoch 9:  72%|███████████████▏     | 2901/4016 [19:16<07:23,  2.52it/s]

[Batch 2900/4016] [Loss: 0.2271]


Training Epoch 9:  75%|███████████████▋     | 3001/4016 [19:56<06:54,  2.45it/s]

[Batch 3000/4016] [Loss: 0.4019]


Training Epoch 9:  77%|████████████████▏    | 3101/4016 [20:36<05:53,  2.59it/s]

[Batch 3100/4016] [Loss: 0.2791]


Training Epoch 9:  80%|████████████████▋    | 3201/4016 [21:16<05:17,  2.57it/s]

[Batch 3200/4016] [Loss: 0.1957]


Training Epoch 9:  82%|█████████████████▎   | 3301/4016 [21:55<04:38,  2.57it/s]

[Batch 3300/4016] [Loss: 0.2180]


Training Epoch 9:  85%|█████████████████▊   | 3401/4016 [22:36<04:02,  2.54it/s]

[Batch 3400/4016] [Loss: 0.3838]


Training Epoch 9:  87%|██████████████████▎  | 3501/4016 [23:17<03:40,  2.34it/s]

[Batch 3500/4016] [Loss: 0.5010]


Training Epoch 9:  90%|██████████████████▊  | 3601/4016 [23:56<02:43,  2.54it/s]

[Batch 3600/4016] [Loss: 0.4783]


Training Epoch 9:  92%|███████████████████▎ | 3701/4016 [24:36<02:00,  2.61it/s]

[Batch 3700/4016] [Loss: 0.4197]


Training Epoch 9:  95%|███████████████████▉ | 3801/4016 [25:16<01:24,  2.53it/s]

[Batch 3800/4016] [Loss: 0.3486]


Training Epoch 9:  97%|████████████████████▍| 3901/4016 [25:57<00:45,  2.54it/s]

[Batch 3900/4016] [Loss: 0.4546]


Training Epoch 9: 100%|████████████████████▉| 4001/4016 [26:37<00:06,  2.26it/s]

[Batch 4000/4016] [Loss: 0.3484]


Training Epoch 9: 100%|█████████████████████| 4016/4016 [26:43<00:00,  2.50it/s]


Epoch 8 finished


Training Epoch 10:   0%|                       | 1/4016 [00:00<26:17,  2.54it/s]

[Batch 0/4016] [Loss: 0.3164]


Training Epoch 10:   3%|▌                    | 101/4016 [00:40<25:00,  2.61it/s]

[Batch 100/4016] [Loss: 0.2390]


Training Epoch 10:   5%|█                    | 201/4016 [01:20<25:00,  2.54it/s]

[Batch 200/4016] [Loss: 0.2590]


Training Epoch 10:   7%|█▌                   | 301/4016 [02:00<24:58,  2.48it/s]

[Batch 300/4016] [Loss: 0.3772]


Training Epoch 10:  10%|██                   | 401/4016 [02:41<27:22,  2.20it/s]

[Batch 400/4016] [Loss: 0.1653]


Training Epoch 10:  12%|██▌                  | 501/4016 [03:21<22:41,  2.58it/s]

[Batch 500/4016] [Loss: 0.2598]


Training Epoch 10:  15%|███▏                 | 601/4016 [04:00<21:34,  2.64it/s]

[Batch 600/4016] [Loss: 0.5288]


Training Epoch 10:  17%|███▋                 | 701/4016 [04:40<21:32,  2.56it/s]

[Batch 700/4016] [Loss: 0.5566]


Training Epoch 10:  20%|████▏                | 801/4016 [05:20<21:55,  2.44it/s]

[Batch 800/4016] [Loss: 0.3726]


Training Epoch 10:  22%|████▋                | 901/4016 [06:00<20:05,  2.58it/s]

[Batch 900/4016] [Loss: 0.3831]


Training Epoch 10:  25%|████▉               | 1001/4016 [06:40<20:02,  2.51it/s]

[Batch 1000/4016] [Loss: 0.6528]


Training Epoch 10:  27%|█████▍              | 1101/4016 [07:20<18:36,  2.61it/s]

[Batch 1100/4016] [Loss: 0.1812]


Training Epoch 10:  30%|█████▉              | 1201/4016 [08:00<18:32,  2.53it/s]

[Batch 1200/4016] [Loss: 0.3430]


Training Epoch 10:  32%|██████▍             | 1301/4016 [08:40<19:07,  2.37it/s]

[Batch 1300/4016] [Loss: 0.4517]


Training Epoch 10:  35%|██████▉             | 1401/4016 [09:19<17:03,  2.55it/s]

[Batch 1400/4016] [Loss: 0.4529]


Training Epoch 10:  37%|███████▍            | 1501/4016 [09:59<15:53,  2.64it/s]

[Batch 1500/4016] [Loss: 0.2810]


Training Epoch 10:  40%|███████▉            | 1601/4016 [10:39<15:28,  2.60it/s]

[Batch 1600/4016] [Loss: 0.1876]


Training Epoch 10:  42%|████████▍           | 1701/4016 [11:19<15:18,  2.52it/s]

[Batch 1700/4016] [Loss: 0.3552]


Training Epoch 10:  45%|████████▉           | 1801/4016 [12:00<16:19,  2.26it/s]

[Batch 1800/4016] [Loss: 0.2128]


Training Epoch 10:  47%|█████████▍          | 1901/4016 [12:40<13:45,  2.56it/s]

[Batch 1900/4016] [Loss: 0.2664]


Training Epoch 10:  50%|█████████▉          | 2001/4016 [13:20<12:46,  2.63it/s]

[Batch 2000/4016] [Loss: 0.3599]


Training Epoch 10:  52%|██████████▍         | 2101/4016 [13:59<12:18,  2.59it/s]

[Batch 2100/4016] [Loss: 0.2695]


Training Epoch 10:  55%|██████████▉         | 2201/4016 [14:40<12:05,  2.50it/s]

[Batch 2200/4016] [Loss: 0.3049]


Training Epoch 10:  57%|███████████▍        | 2301/4016 [15:19<13:26,  2.13it/s]

[Batch 2300/4016] [Loss: 0.2610]


Training Epoch 10:  60%|███████████▉        | 2401/4016 [15:59<10:22,  2.59it/s]

[Batch 2400/4016] [Loss: 0.3193]


Training Epoch 10:  62%|████████████▍       | 2501/4016 [16:39<09:39,  2.62it/s]

[Batch 2500/4016] [Loss: 0.2705]


Training Epoch 10:  65%|████████████▉       | 2601/4016 [17:19<09:35,  2.46it/s]

[Batch 2600/4016] [Loss: 0.2231]


Training Epoch 10:  67%|█████████████▍      | 2701/4016 [17:59<08:45,  2.50it/s]

[Batch 2700/4016] [Loss: 0.5146]


Training Epoch 10:  70%|█████████████▉      | 2801/4016 [18:40<09:18,  2.17it/s]

[Batch 2800/4016] [Loss: 0.4497]


Training Epoch 10:  72%|██████████████▍     | 2901/4016 [19:19<07:06,  2.61it/s]

[Batch 2900/4016] [Loss: 0.5469]


Training Epoch 10:  75%|██████████████▉     | 3001/4016 [19:59<06:27,  2.62it/s]

[Batch 3000/4016] [Loss: 0.2869]


Training Epoch 10:  77%|███████████████▍    | 3101/4016 [20:39<06:01,  2.53it/s]

[Batch 3100/4016] [Loss: 0.2849]


Training Epoch 10:  80%|███████████████▉    | 3201/4016 [21:19<05:35,  2.43it/s]

[Batch 3200/4016] [Loss: 0.1649]


Training Epoch 10:  82%|████████████████▍   | 3301/4016 [21:59<04:33,  2.61it/s]

[Batch 3300/4016] [Loss: 0.3552]


Training Epoch 10:  85%|████████████████▉   | 3401/4016 [22:39<03:53,  2.63it/s]

[Batch 3400/4016] [Loss: 0.3787]


Training Epoch 10:  87%|█████████████████▍  | 3501/4016 [23:19<03:22,  2.55it/s]

[Batch 3500/4016] [Loss: 0.4067]


Training Epoch 10:  90%|█████████████████▉  | 3601/4016 [23:59<02:46,  2.49it/s]

[Batch 3600/4016] [Loss: 0.4578]


Training Epoch 10:  92%|██████████████████▍ | 3701/4016 [24:39<02:15,  2.33it/s]

[Batch 3700/4016] [Loss: 0.3662]


Training Epoch 10:  95%|██████████████████▉ | 3801/4016 [25:19<01:23,  2.59it/s]

[Batch 3800/4016] [Loss: 0.4109]


Training Epoch 10:  97%|███████████████████▍| 3901/4016 [25:59<00:43,  2.63it/s]

[Batch 3900/4016] [Loss: 0.4487]


Training Epoch 10: 100%|███████████████████▉| 4001/4016 [26:39<00:05,  2.59it/s]

[Batch 4000/4016] [Loss: 0.3752]


Training Epoch 10: 100%|████████████████████| 4016/4016 [26:45<00:00,  2.50it/s]


Epoch 9 finished
Training completed ✅


In [17]:
# 1. Load UNet
unet = UNet2DConditionModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5", 
    subfolder="unet"
).to(device)

# 2. Load CLIP text encoder
clip_text_encoder = CLIPTextModel.from_pretrained(
    "openai/clip-vit-large-patch14"
).to(device)
clip_tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")

projection = Projection().to(device)  # Use Projection class definition

# Load the checkpoint
checkpoint = torch.load("final_model_checkpoint.pt", map_location=device)
vae = pipeline.vae.to(device)
scheduler = LMSDiscreteScheduler.from_config(pipeline.scheduler.config)

unet.load_state_dict(checkpoint['unet'])
clip_text_encoder.load_state_dict(checkpoint['clip_text_encoder'])
projection.load_state_dict(checkpoint['projection'])

# Set models to evaluation mode
unet.eval()
clip_text_encoder.eval()
projection.eval()

Cannot initialize model with low cpu memory usage because `accelerate` was not found in the environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install `accelerate` for faster and less memory-intense model loading. You can do so with: 
```
pip install accelerate
```
.
  checkpoint = torch.load("final_model_checkpoint.pt", map_location=device)


Projection(
  (proj): Sequential(
    (0): Linear(in_features=768, out_features=768, bias=True)
    (1): Tanh()
  )
)

### EVALUATION 


In [30]:
print(f"UNet dtype: {next(unet.parameters()).dtype}")
print(f"CLIP dtype: {next(clip_text_encoder.parameters()).dtype}")
print(f"Projection dtype: {next(projection.parameters()).dtype}")
print(f"VAE dtype: {next(vae.parameters()).dtype}")

UNet dtype: torch.float32
CLIP dtype: torch.float32
Projection dtype: torch.float32
VAE dtype: torch.float16


In [40]:
from torch.utils.data import Subset
subset_loader = torch.utils.data.DataLoader(
    Subset(test_dataset, range(500)),  
    batch_size=8,
    shuffle=False,
)

In [None]:
evaluator = Evaluator(device)
evaluator.reset()

# === Determine VAE dtype and latent shape ===
with torch.no_grad():
    sample_batch = next(iter(subset_loader))["image"].to(device)
    vae_dtype = next(vae.parameters()).dtype
    latent = vae.encode(sample_batch.to(dtype=vae_dtype)).latent_dist.sample() * 0.18215
    latent_shape = latent.shape[1:]

# === Evaluation loop ===
with torch.no_grad():
    for batch in tqdm(subset_loader, desc="Evaluating"):
        images = batch["image"].to(device, dtype=vae_dtype)
        texts = batch["caption_text"]

        # === Encode text ===
        inputs = clip_tokenizer(texts, return_tensors="pt", padding=True, truncation=True).to(device)
        text_feats = clip_text_encoder(**inputs).last_hidden_state
        pooled_feats = text_feats.mean(dim=1)
        projected_feats = projection(pooled_feats).to(device=device, dtype=vae_dtype).unsqueeze(1)

        # === Prepare scheduler & latents ===
        latents = torch.randn((images.size(0), *latent_shape), dtype=torch.float32, device=device)
        scheduler.set_timesteps(20)
        latents = latents * scheduler.init_noise_sigma

        # === Denoising loop ===
        for t in scheduler.timesteps:
            t_tensor = torch.tensor([t], device=device).long().expand(images.size(0))
            latents = latents.float()  # ensure float32
            noise_pred = unet(latents, t_tensor, encoder_hidden_states=projected_feats.float()).sample
            latents = scheduler.step(model_output=noise_pred, timestep=t, sample=latents).prev_sample

        # === Decode latents ===
        pred_latents = latents.to(dtype=vae_dtype)
        fake_images = vae.decode(pred_latents / 0.18215).sample
        fake_images = torch.clamp((fake_images + 1.0) / 2.0, 0, 1)

        # === Compute metrics ===
        evaluator.compute_metrics(images, fake_images, texts)

# === Print final scores ===
metrics = evaluator.get_metrics()
print(f"FID: {metrics['fid']:.2f}")
print(f"Inception Score: {metrics['inception_score']:.2f}")
print(f"CLIP Score: {metrics['clip_score']:.2f}")

In [None]:
# === Print final scores ===
metrics =  evaluator.compute_metrics(images, fake_images, texts)
print(f"FID: {metrics['fid']:.2f}")
print(f"Inception Score: {metrics['inception_score']:.2f}")
print(f"CLIP Score: {metrics['clip_score']:.2f}")

In [None]:
import numpy as np
# === Determine VAE dtype and latent shape ===
with torch.no_grad():
    sample_batch = next(iter(subset_loader))["image"].to(device)
    vae_dtype = next(vae.parameters()).dtype
    latent = vae.encode(sample_batch.to(dtype=vae_dtype)).latent_dist.sample() * 0.18215
    latent_shape = latent.shape[1:]

# Custom text for generation
custom_text = "man standing"  # Change this to your desired text

# === Prepare text input ===
with torch.no_grad():
    inputs = clip_tokenizer([custom_text], return_tensors="pt", padding=True, truncation=True).to(device)
    text_feats = clip_text_encoder(**inputs).last_hidden_state
    pooled_feats = text_feats.mean(dim=1)
    projected_feats = projection(pooled_feats).to(device=device, dtype=vae_dtype).unsqueeze(1)

    # === Prepare scheduler & latents ===
    latents = torch.randn((1, *latent_shape), dtype=torch.float32, device=device)
    scheduler.set_timesteps(20)
    latents = latents * scheduler.init_noise_sigma

    # === Denoising loop ===
    for t in scheduler.timesteps:
        t_tensor = torch.tensor([t], device=device).long().expand(1)
        latents = latents.float()  # ensure float32
        
        # Scale model input if scheduler requires it
        if hasattr(scheduler, 'scale_model_input'):
            latents = scheduler.scale_model_input(latents, t)
            
        noise_pred = unet(latents, t_tensor, encoder_hidden_states=projected_feats.float()).sample
        latents = scheduler.step(noise_pred, t, latents).prev_sample

    # === Decode latents ===
    pred_latents = latents.to(dtype=vae_dtype)
    fake_image = vae.decode(pred_latents / 0.18215).sample
    fake_image = torch.clamp((fake_image + 1.0) / 2.0, 0, 1)

# === Convert image for display ===
# Move to CPU and convert to numpy
image_np = fake_image.squeeze(0).permute(1, 2, 0).cpu().numpy()

# Ensure proper dtype and value range
if image_np.dtype != np.float32:
    image_np = image_np.astype(np.float32)
image_np = np.clip(image_np, 0, 1)  # Ensure values are between 0 and 1

# === Plot the generated image ===
plt.figure(figsize=(8, 8))
plt.imshow(image_np)
plt.title(f'Generated Image for: "{custom_text}"')
plt.axis('off')

# Save the figure to avoid display issues
plt.savefig('generated_image.png', bbox_inches='tight', pad_inches=0)
plt.close()

# Display the saved image (alternative method)
from IPython.display import Image, display
display(Image(filename='generated_image.png'))

In [None]:
# Check if MPS is available and set the device accordingly
device = "mps" if torch.backends.mps.is_available() else "cpu"

# Load pre-trained stable diffusion model
pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")
pipe.to(device)

# Generate an image from a text prompt
prompt = "A futuristic cityscape at sunset"
image = pipe(prompt).images[0]

# Save or display the image
image.save("generated_image.png")
image.show()