Skip to content

Commit

Permalink
Add definitive training script
Browse files Browse the repository at this point in the history
  • Loading branch information
acmo0 committed May 31, 2024
1 parent 7077c35 commit 90576aa
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion train_scripts/UNET_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
# Une fois le UNet créé, on définit une fonction de perte :
loss_fn = torch.nn.MSELoss()
#on va utiliser un optimiseur qui va permettre de minimiser la fonction de perte
optimizer = torch.optim.Adam(parameters, lr=0.001) #lr est le learning rate (j'ai utilisé celui de l'exemple de PyTorch)
optimizer = torch.optim.Adam(parameters, lr=0.0001) #lr est le learning rate (j'ai utilisé celui de l'exemple de PyTorch)



Expand All @@ -42,6 +42,7 @@
outputs = model(noise_image)
loss = loss_fn(img, outputs) #calcule les écarts entre les données du modèle et les données réelles
loss.backward() # calcule les gradients de la perte par rapport aux paramètres du modèle
nn.utils.clip_grad_norm_(model.parameters(), 5.0)
optimizer.step() # met à jour l'optimisation
print(loss)
counter+=1
Expand Down

0 comments on commit 90576aa

Please sign in to comment.